# Truncate Cell Embedding using Only Top N Genes

In [2]:
import scanpy as sc

adata_12k = sc.read_h5ad('../pbmc/pre_processed/pbmc_12k.h5ad')
print(adata_12k)

AnnData object with n_obs × n_vars = 11990 × 3346
    obs: 'n_counts', 'batch', 'labels', 'str_labels', 'cell_type', 'batch_key', 'original_n_counts'
    var: 'key_0', 'gene_symbols', 'n_counts-0', 'n_counts-1', 'n_counts', 'ensembl_id', 'gene_symbol', 'entrez_id', 'refseq_id'
    uns: 'cell_types'
    obsm: 'design', 'normalized_qc', 'qc_pc', 'raw_qc'




In [3]:
adata_59k = sc.read_h5ad('../pbmc/pre_processed/pbmc_59k.h5ad')
print(adata_59k)

AnnData object with n_obs × n_vars = 59506 × 23948
    obs: 'cluster', 'n_features', 'mito_pct', 'Annotation', 'rank', 'donor_id', 'time_point', 'age', 'who_max', 'who_d0', 'who_d3', 'who_d7', 'who_d28', 'cardiacevent_72h', 'assay_ontology_term_id', 'cell_type_ontology_term_id', 'development_stage_ontology_term_id', 'disease_ontology_term_id', 'preexisting_heartdisease', 'preexisting_lungdisease', 'preexisting_kidneydisease', 'preexisting_diabetes', 'preexisting_hypertension', 'preexisting_immunocompromisedcondition', 'respiratory_symptoms', 'fever_symptoms', 'gastrointestinal_symptoms', 'self_reported_ethnicity_ontology_term_id', 'is_primary_data', 'organism_ontology_term_id', 'sex_ontology_term_id', 'tissue_ontology_term_id', 'suspension_type', 'tissue_type', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'observation_joinid', 'batch_key', 'n_counts'
    var: 'key_0', 'author_feature_name', 'feature_is_filtered', 'feature



In [4]:
adata_67k = sc.read_h5ad('../pbmc/pre_processed/pbmc_67k.h5ad')
print(adata_67k)

AnnData object with n_obs × n_vars = 66985 × 36263
    obs: 'nCount_RNA', 'nFeature_RNA', 'nCount_HTO', 'nFeature_HTO', 'HTO_maxID', 'HTO_secondID', 'HTO_margin', 'HTO_classification.global', 'sample', 'donor_id', 'CHIP', 'LANE', 'ProjectID', 'MUTATION', 'MUTATION.GROUP', 'sex_ontology_term_id', 'HTOID', 'percent.mt', 'nCount_SCT', 'nFeature_SCT', 'scType_celltype', 'pANN', 'development_stage_ontology_term_id', 'cell_type_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'assay_ontology_term_id', 'suspension_type', 'is_primary_data', 'tissue_type', 'tissue_ontology_term_id', 'organism_ontology_term_id', 'disease_ontology_term_id', 'Clone', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'observation_joinid', 'batch_key', 'n_counts'
    var: 'key_0', 'feature_is_filtered', 'feature_name', 'feature_reference', 'feature_biotype', 'feature_length', 'feature_type', 'ensembl_id', 'gene_symbol', 'entrez_id', 'refseq_id



In [25]:
import scanpy as sc
import os


def truncate_top_genes(adata, n):
    gene_expression_sum = adata.X.sum(axis=0).A1
    top_genes_idx = gene_expression_sum.argsort()[-n:][::-1]
    return adata[:, top_genes_idx]


directory = '../pbmc/pre_processed'
output_directory = '../pbmc_truncated_genes/pre_processed_200'
os.makedirs(output_directory, exist_ok=True)
for filename in os.listdir(directory):
    print(f'Processing {filename}')
    file_path = os.path.join(directory, filename)
    adata = sc.read_h5ad(file_path)
    adata = truncate_top_genes(adata, 200)
    adata.write_h5ad(os.path.join(output_directory, filename))
    print(adata)

Processing pbmc_12k.h5ad
View of AnnData object with n_obs × n_vars = 11990 × 200
    obs: 'n_counts', 'batch', 'labels', 'str_labels', 'cell_type', 'batch_key', 'original_n_counts'
    var: 'key_0', 'gene_symbols', 'n_counts-0', 'n_counts-1', 'n_counts', 'ensembl_id', 'gene_symbol', 'entrez_id', 'refseq_id'
    uns: 'cell_types'
    obsm: 'design', 'normalized_qc', 'qc_pc', 'raw_qc'
Processing pbmc_59k.h5ad
View of AnnData object with n_obs × n_vars = 59506 × 200
    obs: 'cluster', 'n_features', 'mito_pct', 'Annotation', 'rank', 'donor_id', 'time_point', 'age', 'who_max', 'who_d0', 'who_d3', 'who_d7', 'who_d28', 'cardiacevent_72h', 'assay_ontology_term_id', 'cell_type_ontology_term_id', 'development_stage_ontology_term_id', 'disease_ontology_term_id', 'preexisting_heartdisease', 'preexisting_lungdisease', 'preexisting_kidneydisease', 'preexisting_diabetes', 'preexisting_hypertension', 'preexisting_immunocompromisedcondition', 'respiratory_symptoms', 'fever_symptoms', 'gastrointestina

In [29]:
# Generate cell embeddings with Geneformer.
"""
Geneformer configs
"""
geneformer_configs = dict(
    # The directory to store preprocessed data.
    preprocessed_data_directory='../pbmc_truncated_genes/pre_processed_500',
    # The output tokenized file directory.
    tokenized_file_directory="../pbmc_truncated_genes/tokenized_500",
    # The output tokenized filename prefix.
    tokenized_file_prefix='geneformer',
    # The output embedding file directory.
    embedding_output_directory="../pbmc_truncated_genes/",
    # The output embedding file name.
    embedding_output_filename="geneformer_cell_embeddings_truncated_top_500_genes",
    # Directory of the Geneformer pre-trained model.
    load_model_dir="../embedding_extractors/models/geneformer/model/",
    # List of cell attribute labels to keep, i.e. `cell_type` and `batch_key`. If none, use empty list [].
    custom_cell_attr_names=['cell_type', 'batch_key', 'n_counts'],
)

from embedding_extractors import EmbeddingExtractor

emb_extractor = EmbeddingExtractor("Geneformer", output_file_type='h5ad', configs=geneformer_configs)
emb_extractor.tokenize()
emb_extractor.extract_embeddings()

Tokenizing ..\pbmc_truncated_genes\pre_processed_500\pbmc_12k.h5ad
..\pbmc_truncated_genes\pre_processed_500\pbmc_12k.h5ad has no column attribute 'filter_pass'; tokenizing all cells.
Tokenizing ..\pbmc_truncated_genes\pre_processed_500\pbmc_59k.h5ad
..\pbmc_truncated_genes\pre_processed_500\pbmc_59k.h5ad has no column attribute 'filter_pass'; tokenizing all cells.
Tokenizing ..\pbmc_truncated_genes\pre_processed_500\pbmc_67k.h5ad
..\pbmc_truncated_genes\pre_processed_500\pbmc_67k.h5ad has no column attribute 'filter_pass'; tokenizing all cells.
Creating dataset.


Map: 100%|██████████| 138481/138481 [00:19<00:00, 7261.80 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 138481/138481 [00:00<00:00, 1798451.25 examples/s]


Tokenization completed for Geneformer.
Extracting Geneformer embeddings


100%|██████████| 13849/13849 [32:22<00:00,  7.13it/s] 


Output embedding in ../pbmc_truncated_genes/geneformer_cell_embeddings_truncated_top_500_genes.h5ad



In [30]:
# Generate cell embeddings with Geneformer.
"""
Geneformer configs
"""
geneformer_configs = dict(
    # The directory to store preprocessed data.
    preprocessed_data_directory='../pbmc_truncated_genes/pre_processed_1k',
    # The output tokenized file directory.
    tokenized_file_directory="../pbmc_truncated_genes/tokenized_1k",
    # The output tokenized filename prefix.
    tokenized_file_prefix='geneformer',
    # The output embedding file directory.
    embedding_output_directory="../pbmc_truncated_genes/",
    # The output embedding file name.
    embedding_output_filename="geneformer_cell_embeddings_truncated_top_1k_genes",
    # Directory of the Geneformer pre-trained model.
    load_model_dir="../embedding_extractors/models/geneformer/model/",
    # List of cell attribute labels to keep, i.e. `cell_type` and `batch_key`. If none, use empty list [].
    custom_cell_attr_names=['cell_type', 'batch_key', 'n_counts'],
)

from embedding_extractors import EmbeddingExtractor

emb_extractor = EmbeddingExtractor("Geneformer", output_file_type='h5ad', configs=geneformer_configs)
emb_extractor.tokenize()
emb_extractor.extract_embeddings()

Tokenizing ..\pbmc_truncated_genes\pre_processed_1k\pbmc_12k.h5ad
..\pbmc_truncated_genes\pre_processed_1k\pbmc_12k.h5ad has no column attribute 'filter_pass'; tokenizing all cells.
Tokenizing ..\pbmc_truncated_genes\pre_processed_1k\pbmc_59k.h5ad
..\pbmc_truncated_genes\pre_processed_1k\pbmc_59k.h5ad has no column attribute 'filter_pass'; tokenizing all cells.
Tokenizing ..\pbmc_truncated_genes\pre_processed_1k\pbmc_67k.h5ad
..\pbmc_truncated_genes\pre_processed_1k\pbmc_67k.h5ad has no column attribute 'filter_pass'; tokenizing all cells.
Creating dataset.


Map: 100%|██████████| 138481/138481 [00:36<00:00, 3791.86 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 138481/138481 [00:00<00:00, 1064392.27 examples/s]


Tokenization completed for Geneformer.
Extracting Geneformer embeddings


100%|██████████| 13849/13849 [56:15<00:00,  4.10it/s] 


Output embedding in ../pbmc_truncated_genes/geneformer_cell_embeddings_truncated_top_1k_genes.h5ad



In [31]:
# Generate cell embeddings with Geneformer.
"""
Geneformer configs
"""
geneformer_configs = dict(
    # The directory to store preprocessed data.
    preprocessed_data_directory='../pbmc_truncated_genes/pre_processed_3k',
    # The output tokenized file directory.
    tokenized_file_directory="../pbmc_truncated_genes/tokenized_3k",
    # The output tokenized filename prefix.
    tokenized_file_prefix='geneformer',
    # The output embedding file directory.
    embedding_output_directory="../pbmc_truncated_genes/",
    # The output embedding file name.
    embedding_output_filename="geneformer_cell_embeddings_truncated_top_3k_genes",
    # Directory of the Geneformer pre-trained model.
    load_model_dir="../embedding_extractors/models/geneformer/model/",
    # List of cell attribute labels to keep, i.e. `cell_type` and `batch_key`. If none, use empty list [].
    custom_cell_attr_names=['cell_type', 'batch_key', 'n_counts'],
)

from embedding_extractors import EmbeddingExtractor

emb_extractor = EmbeddingExtractor("Geneformer", output_file_type='h5ad', configs=geneformer_configs)
emb_extractor.tokenize()
emb_extractor.extract_embeddings()

Tokenizing ..\pbmc_truncated_genes\pre_processed_3k\pbmc_12k.h5ad


100%|██████████| 24/24 [00:00<00:00, 33.00it/s]


..\pbmc_truncated_genes\pre_processed_3k\pbmc_12k.h5ad has no column attribute 'filter_pass'; tokenizing all cells.
Tokenizing ..\pbmc_truncated_genes\pre_processed_3k\pbmc_59k.h5ad
..\pbmc_truncated_genes\pre_processed_3k\pbmc_59k.h5ad has no column attribute 'filter_pass'; tokenizing all cells.
Tokenizing ..\pbmc_truncated_genes\pre_processed_3k\pbmc_67k.h5ad
..\pbmc_truncated_genes\pre_processed_3k\pbmc_67k.h5ad has no column attribute 'filter_pass'; tokenizing all cells.
Creating dataset.


Map: 100%|██████████| 138481/138481 [00:53<00:00, 2607.52 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 138481/138481 [00:00<00:00, 723904.65 examples/s]


Tokenization completed for Geneformer.
Extracting Geneformer embeddings


100%|██████████| 13849/13849 [51:18<00:00,  4.50it/s]  


Output embedding in ../pbmc_truncated_genes/geneformer_cell_embeddings_truncated_top_3k_genes.h5ad



In [32]:
import pandas as pd
import glob
from pathlib import Path

directory = '../pbmc_truncated_genes/*.h5ad'
for file_path in glob.glob(directory):
    print(f'Processing {file_path}')
    cell_emb = sc.read_h5ad(file_path)
    print(cell_emb.obsm['X_Geneformer'].shape)
    output_df = pd.concat([pd.DataFrame(cell_emb.obs['batch_key']).reset_index(drop=True),
                           pd.DataFrame(cell_emb.obs['n_counts']).reset_index(drop=True),
                           pd.DataFrame(cell_emb.obsm['X_Geneformer']).reset_index(drop=True)], axis=1)
    output_df.to_csv(Path(file_path).with_suffix('.csv'))

Processing ../pbmc_truncated_genes\geneformer_cell_embeddings_truncated_top_1k_genes.h5ad
(138481, 896)
Processing ../pbmc_truncated_genes\geneformer_cell_embeddings_truncated_top_200_genes.h5ad
(138481, 896)
Processing ../pbmc_truncated_genes\geneformer_cell_embeddings_truncated_top_3k_genes.h5ad
(138481, 896)
Processing ../pbmc_truncated_genes\geneformer_cell_embeddings_truncated_top_500_genes.h5ad
(138481, 896)


In [10]:
import scanpy as sc

adata = sc.read_h5ad('../pbmc_truncated_genes/geneformer_cell_embeddings_truncated_top_1k_genes.h5ad')
print(adata.obs['n_counts'][:10])

0    2277.666992
1    2272.825439
2    2270.584229
3    2292.384521
4    2503.839111
5    2179.061768
6    2125.930176
7    2417.371338
8    2164.579590
9    2357.080078
Name: n_counts, dtype: float64


In [12]:
import scanpy as sc

adata = sc.read_h5ad('../pbmc_truncated_genes/pre_processed_1k/pbmc_12k.h5ad')
print(adata.obs['n_counts'][:100])

AAACCTGAGCTAGTGG-1    1573.0
AAACCTGCACATTAGC-1     892.0
AAACCTGCACTGTTAG-1    1389.0
AAACCTGCATAGTAAG-1    1487.0
AAACCTGCATGAACCT-1     948.0
                       ...  
AAATGCCTCAACACAC-1    1693.0
AAATGCCTCACCAGGC-1    1240.0
AAATGCCTCCAGAAGG-1    2059.0
AAATGCCTCCTGCCAT-1    1149.0
AACACGTAGAGGTTGC-1     937.0
Name: n_counts, Length: 100, dtype: float32


In [21]:
import scanpy as sc

adata = sc.read_h5ad('../pbmc/pre_processed/pbmc_59k.h5ad')
print(adata.obs['n_counts'][:10])

cell
batch1_5p_rna|AAACCTGAGAAACGAG-1    1900.751587
batch1_5p_rna|AAACCTGAGAGCTGCA-1    1592.616699
batch1_5p_rna|AAACCTGAGAGTTGGC-1    1869.868408
batch1_5p_rna|AAACCTGAGATCCCGC-1    1662.501343
batch1_5p_rna|AAACCTGAGCAAATCA-1    1745.158813
batch1_5p_rna|AAACCTGAGCCAACAG-1    2031.940063
batch1_5p_rna|AAACCTGAGCGAAGGG-1    1816.160400
batch1_5p_rna|AAACCTGAGCGTAATA-1    2110.902588
batch1_5p_rna|AAACCTGCAAGCCGCT-1    1967.065186
batch1_5p_rna|AAACCTGCACCTATCC-1    1947.370850
Name: n_counts, dtype: float32




In [25]:
import scanpy as sc

adata = sc.read_h5ad('../pbmc/raw_data/pbmc_59k.h5ad')
print(adata.X[0])
print(adata.X[0].toarray().sum())

<Compressed Sparse Row sparse matrix of dtype 'float32'
	with 2185 stored elements and shape (1, 23948)>
  Coords	Values
  (0, 27)	0.8788620233535767
  (0, 28)	0.8788620233535767
  (0, 54)	0.5053439736366272
  (0, 60)	0.5053439736366272
  (0, 63)	0.5053439736366272
  (0, 79)	0.5053439736366272
  (0, 88)	0.5053439736366272
  (0, 104)	0.5053439736366272
  (0, 118)	3.294059991836548
  (0, 128)	0.5053439736366272
  (0, 146)	0.5053439736366272
  (0, 151)	0.5053439736366272
  (0, 155)	1.1752899885177612
  (0, 208)	0.5053439736366272
  (0, 220)	0.5053439736366272
  (0, 248)	0.5053439736366272
  (0, 251)	0.5053439736366272
  (0, 262)	0.5053439736366272
  (0, 278)	0.5053439736366272
  (0, 295)	0.8788620233535767
  (0, 297)	1.4210699796676636
  (0, 318)	0.5053439736366272
  (0, 321)	1.1752899885177612
  (0, 333)	1.4210699796676636
  (0, 343)	0.5053439736366272
  :	:
  (0, 23754)	0.8788620233535767
  (0, 23783)	0.5053439736366272
  (0, 23788)	0.5053439736366272
  (0, 23802)	0.5053439736366272
  (

In [20]:
import scanpy as sc

adata = sc.read_h5ad('../pbmc_masked_20/pre_processed/pbmc_59k.h5ad')
print(adata.obs['n_counts'][:10])

cell
batch1_5p_rna|AAACCTGAGAAACGAG-1    1520.601196
batch1_5p_rna|AAACCTGAGAGCTGCA-1    1274.093140
batch1_5p_rna|AAACCTGAGAGTTGGC-1    1495.895020
batch1_5p_rna|AAACCTGAGATCCCGC-1    1330.000977
batch1_5p_rna|AAACCTGAGCAAATCA-1    1396.127319
batch1_5p_rna|AAACCTGAGCCAACAG-1    1625.552002
batch1_5p_rna|AAACCTGAGCGAAGGG-1    1452.928223
batch1_5p_rna|AAACCTGAGCGTAATA-1    1688.722168
batch1_5p_rna|AAACCTGCAAGCCGCT-1    1573.652344
batch1_5p_rna|AAACCTGCACCTATCC-1    1557.896729
Name: n_counts, dtype: float32




In [17]:
from datasets import load_from_disk
ds = load_from_disk('../pbmc_truncated_genes/tokenized_1k/geneformer.dataset').to_pandas()
print(ds)

                                                input_ids  \
0       [2, 17593, 15158, 13986, 11659, 218, 18068, 17...   
1       [2, 13986, 5075, 1124, 12868, 17593, 13598, 29...   
2       [2, 10773, 16620, 18113, 9814, 8412, 395, 1733...   
3       [2, 10773, 18113, 1357, 12558, 16134, 2509, 16...   
4       [2, 11797, 5670, 18458, 11993, 7742, 14225, 13...   
...                                                   ...   
138476  [2, 7655, 12482, 842, 1006, 1882, 14447, 11482...   
138477  [2, 3298, 5536, 3387, 11889, 4841, 3735, 2347,...   
138478  [2, 13928, 6725, 6650, 12482, 13797, 6245, 842...   
138479  [2, 18466, 16533, 7224, 3053, 3735, 10132, 878...   
138480  [2, 16924, 8667, 2552, 7093, 547, 11750, 10134...   

                                                cell_type batch_key  \
0                                             CD4 T cells  pmbc_12k   
1                                             CD4 T cells  pmbc_12k   
2                                         CD14+ Monocy

In [19]:
print(ds['n_counts'][:10])

0    1573.0
1     892.0
2    1389.0
3    1487.0
4     948.0
5    1678.0
6    1813.0
7    1396.0
8    1026.0
9    2072.0
Name: n_counts, dtype: float64
