In [1]:
%load_ext autoreload
%autoreload 2


In [2]:
import pandas as pd
pd.set_option('display.max_rows', 1000)
import sys
from pathlib import Path

repo_dir = Path.cwd().parent.absolute()
sys.path.append(str(repo_dir))



In [3]:
from src.utils import setup_data_dir
from pathlib import Path
setup_data_dir()
data_dir = repo_dir / "data"

File already exists at /Users/rj/personal/GenePT-tools/data/GenePT_emebdding_v2.zip
Extracting files...
Extracting GenePT_emebdding_v2/
Skipping GenePT_emebdding_v2/NCBI_UniProt_summary_of_genes.json - already exists with same size
Skipping GenePT_emebdding_v2/GenePT_gene_embedding_ada_text.pickle - already exists with same size
Skipping GenePT_emebdding_v2/GenePT_gene_protein_embedding_model_3_text.pickle. - already exists with same size
Skipping GenePT_emebdding_v2/NCBI_summary_of_genes.json - already exists with same size
Extraction complete!
Setup finished!


In [4]:
import requests

dataset = "https://datasets.cellxgene.cziscience.com/10df7690-6d10-4029-a47e-0f071bb2df83.h5ad"
# dataset_id = "10df7690-6d10-4029-a47e-0f071bb2df83"

file_path = data_dir / "1m_cells.h5ad"  # adjust this path as needed



if not file_path.exists():
    response = requests.get(dataset, stream=True)
    with open(file_path, 'wb') as file:
        for chunk in response.iter_content(chunk_size=8192):
            if chunk:  # filter out keep-alive chunks
                file.write(chunk)

In [5]:
import h5py

with h5py.File(file_path, 'r') as f:
    # Look at the structure of the X group
    print("Contents of X group:", list(f['X'].keys()))
    
    # Look at obs and var to get dimensions
    print("\nContents of obs group:", list(f['obs'].keys()))
    print("Contents of var group:", list(f['var'].keys()))
    
    # If X contains a sparse matrix, it likely has 'data', 'indices', and 'indptr'
    if 'data' in f['X']:
        print("\nShape of X/data:", f['X']['data'].shape)
        print("Shape of X/indices:", f['X']['indices'].shape)
        print("Shape of X/indptr:", f['X']['indptr'].shape)

Contents of X group: ['data', 'indices', 'indptr']

Contents of obs group: ['10X_run', '_index', '_scvi_batch', '_scvi_labels', 'ambient_removal', 'anatomical_position', 'assay', 'assay_ontology_term_id', 'broad_cell_class', 'cdna_plate', 'cdna_well', 'cell_type', 'cell_type_ontology_term_id', 'compartment', 'development_stage', 'development_stage_ontology_term_id', 'disease', 'disease_ontology_term_id', 'donor_assay', 'donor_id', 'donor_method', 'donor_tissue', 'donor_tissue_assay', 'ethnicity_original', 'free_annotation', 'is_primary_data', 'library_plate', 'manually_annotated', 'method', 'n_genes_by_counts', 'notes', 'observation_joinid', 'organism', 'organism_ontology_term_id', 'pct_counts_ercc', 'pct_counts_mt', 'published_2022', 'replicate', 'sample_id', 'sample_number', 'scvi_leiden_donorassay_full', 'self_reported_ethnicity', 'self_reported_ethnicity_ontology_term_id', 'sex', 'sex_ontology_term_id', 'suspension_type', 'tissue', 'tissue_in_publication', 'tissue_ontology_term_id'

In [14]:
import numpy as np
from scipy import sparse


def load_subset_sparse(h5py_file, start_row=0, n_rows=None):
    """
    Load a subset of rows from the sparse matrix.
    
    Args:
        file_path: Path to h5ad file
        start_row: Starting row index
        n_rows: Number of rows to load
    
    Returns:
        scipy.sparse.csr_matrix with the requested rows
    """
    with h5py.File(file_path, 'r') as f:
        # Get the indptr for the rows we want
        if n_rows is None:
            n_rows = len(f['X']['indptr']) - 1 - start_row

        indptr = f['X']['indptr'][start_row:start_row + n_rows + 1]
        # Find the indices in data array for our rows
        start_idx = indptr[0]
        end_idx = indptr[-1]
        
        # Load the relevant parts of the data and indices
        data = f['X']['data'][start_idx:end_idx]
        indices = f['X']['indices'][start_idx:end_idx]
        
        # Adjust indptr to start at 0
        indptr = indptr - start_idx
        
        # Get the total number of columns from the var group
        n_cols = len(f['var']['feature_name']['categories'])
        
        # Create the sparse matrix
        return sparse.csr_matrix((data, indices, indptr), shape=(n_rows, n_cols))

cell_gene_matrix = load_subset_sparse(file_path, start_row=0, n_rows=100000)
print("Matrix shape:", cell_gene_matrix.shape)
print("Matrix density:", cell_gene_matrix.nnz / (cell_gene_matrix.shape[0] * cell_gene_matrix.shape[1]))

Matrix shape: (100000, 61759)
Matrix density: 0.045159779789180524


In [19]:
with h5py.File(file_path, 'r') as f:
    print(f['var']['feature_name']["categories"])
    # print(f['X']['indices'][:10])

<HDF5 dataset "categories": shape (61759,), type "|O">


In [20]:
cell_gene_matrix

<Compressed Sparse Row sparse matrix of dtype 'float32'
	with 278902284 stored elements and shape (100000, 61759)>

In [77]:
with h5py.File(file_path, 'r') as f:
    # Get the indptr for the rows we want
    print(len(f['obs']['scvi_leiden_donorassay_full']['codes']))
    
    

1136218


In [78]:
with h5py.File(file_path, 'r') as f:
    # Get the indptr for the rows we want
    gene_names = f['var']['feature_name']
    ensembl_ids = f['var']['ensembl_id']
    scvi_leiden_donorassay_full = f['obs']['scvi_leiden_donorassay_full']['codes']
    major_ensembl_ids = pd.Series(
        ensembl_id.decode('utf-8').split('.')[0]
        for ensembl_id in ensembl_ids
    )
    

In [23]:
import h5py

with h5py.File(file_path, 'r') as f:
    # Look at the structure of obs group in detail
    obs_group = f['obs']
    print("Type of obs group:", type(obs_group))
    print("Keys in obs group:", list(obs_group.keys()))
    
    # Let's look at one specific column to understand its structure
    cell_type_data = obs_group['cell_type']
    print("\nType of cell_type data:", type(cell_type_data))
    if hasattr(cell_type_data, 'shape'):
        print("Shape of cell_type data:", cell_type_data.shape)

Type of obs group: <class 'h5py._hl.group.Group'>
Keys in obs group: ['10X_run', '_index', '_scvi_batch', '_scvi_labels', 'ambient_removal', 'anatomical_position', 'assay', 'assay_ontology_term_id', 'broad_cell_class', 'cdna_plate', 'cdna_well', 'cell_type', 'cell_type_ontology_term_id', 'compartment', 'development_stage', 'development_stage_ontology_term_id', 'disease', 'disease_ontology_term_id', 'donor_assay', 'donor_id', 'donor_method', 'donor_tissue', 'donor_tissue_assay', 'ethnicity_original', 'free_annotation', 'is_primary_data', 'library_plate', 'manually_annotated', 'method', 'n_genes_by_counts', 'notes', 'observation_joinid', 'organism', 'organism_ontology_term_id', 'pct_counts_ercc', 'pct_counts_mt', 'published_2022', 'replicate', 'sample_id', 'sample_number', 'scvi_leiden_donorassay_full', 'self_reported_ethnicity', 'self_reported_ethnicity_ontology_term_id', 'sex', 'sex_ontology_term_id', 'suspension_type', 'tissue', 'tissue_in_publication', 'tissue_ontology_term_id', 'tis

In [24]:
from datasets import load_dataset

gene_info_table_dataset = load_dataset("honicky/genept-composable-embeddings-source-data", "gene_info")
gene_info_table = gene_info_table_dataset['train'].to_pandas()
gene_info_table.head()


  from .autonotebook import tqdm as notebook_tqdm


Unnamed: 0,gene_name,ensembl_id,gene_type
0,TSPAN6,ENSG00000000003,protein_coding
1,TNMD,ENSG00000000005,protein_coding
2,DPM1,ENSG00000000419,protein_coding
3,SCYL3,ENSG00000000457,protein_coding
4,C1orf112,ENSG00000000460,protein_coding


In [25]:
gene_embedding = pd.read_parquet(data_dir / "generated/embeddings/embedding_associations_age_drugs_pathways_openai_large.parquet")


# Embed using the mean for duplicate Ensembl IDs


In [26]:
gene_embeddings_with_ensembl_id = gene_embedding.merge(gene_info_table, left_index=True, right_on='gene_name')



In [27]:
gene_embeddings_with_ensembl_id[gene_embeddings_with_ensembl_id.ensembl_id == 'ENSG00000222005']

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,3065,3066,3067,3068,3069,3070,3071,gene_name,ensembl_id,gene_type
83316,-0.017322,0.015861,-0.009243,0.002859,0.002974,-0.015873,0.029082,-0.009793,0.026003,0.014036,...,-0.027633,0.008933,-0.018642,0.021656,-0.008706,-0.005382,-0.021889,LINC01118,ENSG00000222005,
83105,-0.028728,0.013531,-0.01305,-0.003901,0.019633,-0.024284,0.029168,0.003008,0.012114,0.019853,...,-0.017858,-0.001864,-0.021449,0.01812,-0.018822,-0.017803,-0.023334,LINC01119,ENSG00000222005,


In [28]:
import numpy as np

matching_ensembl_ids = pd.DataFrame(major_ensembl_ids, columns=['ensembl_id']).merge(gene_embeddings_with_ensembl_id, left_on='ensembl_id', right_on='ensembl_id')[['gene_name', 'ensembl_id']]
ensembl_id_counts = matching_ensembl_ids.ensembl_id.value_counts()
ensembl_id_counts[ensembl_id_counts > 1]


ensembl_id
ENSG00000000003    2
ENSG00000243485    2
ENSG00000222005    2
ENSG00000264405    2
ENSG00000204792    2
ENSG00000201388    2
ENSG00000222345    2
ENSG00000276234    2
ENSG00000187838    2
ENSG00000202377    2
ENSG00000226364    2
ENSG00000269433    2
ENSG00000183598    2
ENSG00000267151    2
ENSG00000236790    2
ENSG00000264073    2
ENSG00000265134    2
ENSG00000238936    2
ENSG00000270722    2
ENSG00000226444    2
ENSG00000269099    2
ENSG00000204397    2
ENSG00000249532    2
ENSG00000206603    2
ENSG00000284917    2
ENSG00000268942    2
ENSG00000206903    2
ENSG00000063587    2
ENSG00000274020    2
ENSG00000255154    2
ENSG00000090857    2
ENSG00000254508    2
ENSG00000206897    2
ENSG00000251866    2
ENSG00000226419    2
ENSG00000264448    2
ENSG00000206785    2
ENSG00000197927    2
ENSG00000245080    2
ENSG00000269586    2
ENSG00000227518    2
ENSG00000145491    2
ENSG00000221164    2
ENSG00000250331    2
ENSG00000223770    2
ENSG00000269955    2
ENSG00000207187    2
EN

In [29]:
# Get embeddings without metadata columns
embedding_cols = [col for col in gene_embeddings_with_ensembl_id.columns 
                 if col not in ['gene_name', 'ensembl_id', 'gene_type']]

# Group by ensembl_id and take mean of embeddings
merged_embeddings = (gene_embeddings_with_ensembl_id
    .groupby('ensembl_id')[embedding_cols]
    .mean()
    .reset_index())

# Renormalize the embeddings
embedding_values = merged_embeddings[embedding_cols].values
norms = np.linalg.norm(embedding_values, axis=1, keepdims=True)
merged_embeddings[embedding_cols] = embedding_values / norms

# Verify the results
print(f"Original shape: {gene_embeddings_with_ensembl_id.shape}")
print(f"After merging duplicates: {merged_embeddings.shape}")

# Verify all duplicates are resolved
duplicate_check = merged_embeddings.ensembl_id.value_counts()
print("\nNumber of remaining duplicates:", (duplicate_check > 1).sum())

Original shape: (37220, 3075)
After merging duplicates: (36573, 3073)

Number of remaining duplicates: 0


In [30]:
merged_embeddings.ensembl_id

0        ENSG00000000003
1        ENSG00000000005
2        ENSG00000000419
3        ENSG00000000457
4        ENSG00000000460
              ...       
36568    ENSGR0000230542
36569    ENSGR0000234958
36570    ENSGR0000236871
36571    ENSGR0000237040
36572    ENSGR0000265658
Name: ensembl_id, Length: 36573, dtype: object

In [31]:
# # Get the embedding values without the metadata columns
# embedding_cols = [col for col in merged_embeddings.columns 
#                  if col not in ['ensembl_id']]

# # Create a mapping from major_ensembl_ids to column indices in cell_gene_matrix
# gene_idx_map = {gene_id: idx for idx, gene_id in enumerate(major_ensembl_ids)}

# # Find which embeddings correspond to genes in our expression matrix
# # and get their indices in the correct order
# valid_indices = []
# embedding_indices = []
# for i, ensembl_id in enumerate(merged_embeddings.ensembl_id):
#     if ensembl_id in gene_idx_map:
#         valid_indices.append(gene_idx_map[ensembl_id])
#         embedding_indices.append(i)

# # Create the reordered embedding matrix
# embedding_matrix = merged_embeddings[embedding_cols].iloc[embedding_indices].values.T


In [32]:
def create_embedding_matrix(merged_embeddings, major_ensembl_ids):
    """
    Create a reordered embedding matrix that aligns gene embeddings with expression matrix columns.
    
    Args:
        merged_embeddings (pd.DataFrame): DataFrame containing gene embeddings with 'ensembl_id' column
        major_ensembl_ids (pd.Series): Series of Ensembl IDs in the order they appear in expression matrix
    
    Returns:
        tuple: (embedding_matrix, valid_indices)
            - embedding_matrix: numpy array of shape (n_embedding_dims, n_valid_genes)
            - valid_indices: list of indices mapping to original expression matrix columns
    """
    # Get the embedding values without the metadata columns
    embedding_cols = [col for col in merged_embeddings.columns 
                     if col not in ['ensembl_id']]
    
    # Create a mapping from major_ensembl_ids to column indices in cell_gene_matrix
    gene_idx_map = {gene_id: idx for idx, gene_id in enumerate(major_ensembl_ids)}
    
    # Find which embeddings correspond to genes in our expression matrix
    # and get their indices in the correct order
    valid_indices = []
    embedding_indices = []
    for i, ensembl_id in enumerate(merged_embeddings.ensembl_id):
        if ensembl_id in gene_idx_map:
            valid_indices.append(gene_idx_map[ensembl_id])
            embedding_indices.append(i)
    
    # Create the reordered embedding matrix
    embedding_matrix = merged_embeddings[embedding_cols].iloc[embedding_indices].values.T
    
    return embedding_matrix, valid_indices

In [33]:
def create_cell_embeddings(expression_matrix, embedding_matrix, valid_indices):
    """
    Create normalized cell embeddings from gene expression data and gene embeddings.
    
    Args:
        expression_matrix: scipy.sparse.csr_matrix or numpy array of shape (n_cells, n_genes)
        embedding_matrix: numpy array of shape (n_embedding_dims, n_valid_genes)
        valid_indices: list of indices to select genes that have embeddings
    
    Returns:
        numpy array of shape (n_cells, n_embedding_dims) containing normalized cell embeddings
    """
    # Select only the columns from expression matrix that have corresponding embeddings
    filtered_expression = expression_matrix[:, valid_indices]
    
    # Perform the matrix multiplication (n_cells x n_embedding_dimensions)
    cell_embeddings = filtered_expression @ embedding_matrix.T
    
    # Normalize the cell embeddings
    norms = np.linalg.norm(cell_embeddings, axis=1, keepdims=True)
    cell_embeddings = cell_embeddings / norms
    
    return cell_embeddings


In [34]:
embedding_matrix, valid_indices = create_embedding_matrix(merged_embeddings, major_ensembl_ids)

In [36]:

cell_embeddings = create_cell_embeddings(cell_gene_matrix, embedding_matrix, valid_indices)
print(f"Shape of cell embeddings: {cell_embeddings.shape}")
print(f"Number of genes used: {len(valid_indices)} out of {len(major_ensembl_ids)} total genes")

Shape of cell embeddings: (100000, 3072)
Number of genes used: 33258 out of 61759 total genes


In [37]:
cell_embeddings

array([[-0.01288147,  0.00576901, -0.0066504 , ..., -0.00508329,
        -0.00690817, -0.01344143],
       [-0.01111397,  0.00730632, -0.0066569 , ..., -0.00482406,
        -0.00684726, -0.01372927],
       [-0.01294018,  0.00668749, -0.00683006, ..., -0.00455581,
        -0.00719775, -0.01300113],
       ...,
       [-0.01232971,  0.00845843, -0.00676753, ..., -0.00433764,
        -0.00803416, -0.01278692],
       [-0.01180067,  0.00788706, -0.00691997, ..., -0.00421921,
        -0.00748141, -0.01385786],
       [-0.01259188,  0.00802049, -0.00677291, ..., -0.00467154,
        -0.00810091, -0.01343567]])

In [38]:


# # Select only the columns from cell_gene_matrix that have corresponding embeddings
# filtered_expression = cell_gene_matrix[:, valid_indices]

# # Perform the matrix multiplication
# # This will give us (n_cells x n_embedding_dimensions)
# cell_embeddings = filtered_expression @ embedding_matrix.T
# # Normalize the cell embeddings
# norms = np.linalg.norm(cell_embeddings, axis=1, keepdims=True)
# cell_embeddings = cell_embeddings / norms


# print(f"Shape of cell embeddings: {cell_embeddings.shape}")
# print(f"Number of genes used: {len(valid_indices)} out of {len(major_ensembl_ids)} total genes")

In [39]:
cell_embeddings

array([[-0.01288147,  0.00576901, -0.0066504 , ..., -0.00508329,
        -0.00690817, -0.01344143],
       [-0.01111397,  0.00730632, -0.0066569 , ..., -0.00482406,
        -0.00684726, -0.01372927],
       [-0.01294018,  0.00668749, -0.00683006, ..., -0.00455581,
        -0.00719775, -0.01300113],
       ...,
       [-0.01232971,  0.00845843, -0.00676753, ..., -0.00433764,
        -0.00803416, -0.01278692],
       [-0.01180067,  0.00788706, -0.00691997, ..., -0.00421921,
        -0.00748141, -0.01385786],
       [-0.01259188,  0.00802049, -0.00677291, ..., -0.00467154,
        -0.00810091, -0.01343567]])

In [40]:
def load_cell_metadata(file_path, start_row=0, n_rows=1000, columns=None):
    """
    Load metadata for specific cells.
    
    Args:
        file_path: Path to h5ad file
        start_row: Starting row index
        n_rows: Number of rows to load
        columns: List of metadata columns to load (if None, load all)
    
    Returns:
        Dictionary of metadata arrays
    """
    with h5py.File(file_path, 'r') as f:
        metadata = {}
        obs_keys = list(f['obs'].keys()) if columns is None else columns
        for key in obs_keys:
            if key in f['obs']:
                column_group = f['obs'][key]
                if isinstance(column_group, h5py.Dataset):
                    metadata[key] = column_group[start_row:start_row + n_rows]
                else:
                    if 'categories' in column_group and 'codes' in column_group:
                        # Get categories and decode from bytes to strings
                        categories = [cat.decode('utf-8') for cat in column_group['categories'][:]]
                        codes = column_group['codes'][start_row:start_row + n_rows]
                        metadata[key] = np.array([categories[code] for code in codes])
    return metadata

# Example usage:
metadata = load_cell_metadata(file_path, start_row=0, n_rows=1136219, 
                            columns=['cell_type', 'total_counts', 'broad_cell_class', 'donor_id'])
print("Loaded metadata keys:", list(metadata.keys()))

# Print first few values of each column
for key in metadata:
    print(f"\nFirst 5 values of {key}:")
    print(metadata[key][:5])

Loaded metadata keys: ['cell_type', 'total_counts', 'broad_cell_class', 'donor_id']

First 5 values of cell_type:
['naive thymus-derived CD4-positive, alpha-beta T cell' 'B cell' 'B cell'
 'B cell' 'CD8-positive, alpha-beta T cell']

First 5 values of total_counts:
[648388. 404690. 579976. 496511. 453314.]

First 5 values of broad_cell_class:
['t cell' 'lymphocyte of b lineage' 'lymphocyte of b lineage'
 'lymphocyte of b lineage' 't cell']

First 5 values of donor_id:
['TSP2' 'TSP2' 'TSP2' 'TSP2' 'TSP2']


In [41]:
cell_type_labels = pd.Categorical(pd.Series(metadata['broad_cell_class'])[:len(cell_embeddings)]).codes
cell_type_labels


array([35, 25, 25, ...,  4, 32, 32], dtype=int8)

In [87]:
with h5py.File(file_path, 'r') as f:
    print(pd.Series(f['obs']['broad_cell_class']['categories']))

0                    b'adventitial cell'
1            b'cardiac endothelial cell'
2            b'ciliated epithelial cell'
3        b'conjunctival epithelial cell'
4              b'connective tissue cell'
5                    b'contractile cell'
6                      b'dendritic cell'
7                b'duct epithelial cell'
8                b'ecto-epithelial cell'
9                b'endo-epithelial cell'
10                   b'endothelial cell'
11                    b'epithelial cell'
12            b'epithelial cell of lung'
13             b'erythroid lineage cell'
14                           b'fat cell'
15                   b'female germ cell'
16                         b'fibroblast'
17                           b'follicle'
18          b'glandular epithelial cell'
19                         b'glial cell'
20                        b'granulocyte'
21                 b'hematopoietic cell'
22                         b'hepatocyte'
23               b'innate lymphoid cell'
24         b'int

In [42]:
donor_ids = pd.Categorical(metadata['donor_id'][:len(cell_embeddings)]).codes
donor_ids

array([7, 7, 7, ..., 7, 7, 7], dtype=int8)

In [79]:
cell_embeddings_pdf = pd.DataFrame(cell_embeddings, columns=list(range(cell_embeddings.shape[1])))
cell_embeddings_pdf['cell_type'] = cell_type_labels
cell_embeddings_pdf['donor_id'] = donor_ids
cell_embeddings_pdf['scvi_leiden_donorassay_full'] = scvi_leiden_donorassay_full


ValueError: Invalid dataset identifier (invalid dataset identifier)

In [44]:
cell_embeddings.shape

(100000, 3072)

In [45]:
from sklearn.decomposition import PCA
import umap
import matplotlib.pyplot as plt
import plotly.express as px

# Convert all column names to strings before applying PCA
# features_24_weeks_df.columns = features_24_weeks_df.columns.astype(str)

# Now run PCA
# pca = PCA(n_components=50)
# pca_embeddings = pca.fit_transform(features_24_weeks_df.drop(columns=['drug dose', 'external_id', 'ifn status', "subject sex", "subject age", "subject_id"]))
# Then apply UMAP to the PCA results
reducer = umap.UMAP(random_state=42)
umap_sample_pdf = cell_embeddings_pdf.sample(1000)
umap_embeddings = reducer.fit_transform(umap_sample_pdf.sample(1000).drop(columns=['cell_type', 'donor_id']))

# Create a DataFrame with the UMAP coordinates using the original DataFrame's index
umap_df = pd.DataFrame(umap_embeddings, 
                      columns=['UMAP1', 'UMAP2'], 
                      index=umap_sample_pdf.index)  # Use the original DataFrame's index

# Merge with sample attributes to get metadata for coloring
umap_df = umap_df.merge(umap_sample_pdf[['cell_type', 'donor_id']], 
                       left_index=True, 
                       right_index=True)
# Create the plot
fig = px.scatter(umap_df, x='UMAP1', y='UMAP2', color='cell_type', opacity=0.7,
                 title='UMAP Visualization of Gene Expression Embeddings')

# Update layout
fig.update_layout(
    title={
        'y':0.95,
        'x':0.5,
        'xanchor': 'center',
        'yanchor': 'top'
    }
)

fig.show()

  warn(


In [46]:
import plotly.express as px

px.histogram(cell_embeddings_pdf.cell_type.sort_values())

In [53]:
# Create a cross-tabulation of donor_id and cell_type
heatmap_data = pd.crosstab(cell_embeddings_pdf.donor_id, cell_embeddings_pdf.cell_type)

# Create heatmap using plotly
import plotly.express as px
import numpy as np

# Apply log10 transform to the data (adding 1 to avoid log(0))
log_data = np.log10(heatmap_data.values + 1)

# Create regular heatmap with log-transformed data
fig = px.imshow(
    log_data,
    labels=dict(x="Cell Type", y="Donor ID", color="Count"),
    x=heatmap_data.columns,
    y=heatmap_data.index,
    color_continuous_scale='Viridis',
    title='Cell Type Distribution Across Donors (Log Scale)',
    aspect='auto'
)

# Update hover template to show both log and linear values
fig.data[0].customdata = heatmap_data.values
fig.data[0].hovertemplate = "Cell Type: %{x}<br>Donor ID: %{y}<br>Count: %{customdata:.0f}<br>Log10 Count: %{z:.2f}<extra></extra>"

# Create tick values for the colorbar (in log space)
tick_values = np.linspace(log_data.min(), log_data.max(), 6)
# Convert tick values back to linear space for labels
tick_labels = [f"{int(10**x - 1)}" for x in tick_values]

# Update layout and colorbar
fig.update_layout(
    xaxis_title='Cell Type',
    yaxis_title='Donor ID',
    coloraxis=dict(
        colorbar=dict(
            title='Count',
            tickvals=tick_values,
            ticktext=tick_labels
        )
    )
)

fig.show()

In [56]:
# Get value counts and identify categories with < 200 samples
category_counts = pd.Series(cell_embeddings_pdf.cell_type.value_counts())
small_categories = category_counts[category_counts < 200].index

# Create new column with remapped categories
cell_embeddings_pdf['cell_type_grouped'] = cell_embeddings_pdf.cell_type
cell_embeddings_pdf.loc[cell_embeddings_pdf.cell_type.isin(small_categories), 'cell_type_grouped'] = max(cell_embeddings_pdf.cell_type) + 1

In [57]:
px.histogram(cell_embeddings_pdf.cell_type_grouped.sort_values())

In [58]:
X = cell_embeddings_pdf.copy()
X.columns = X.columns.astype(str)
y = X.cell_type_grouped
X.drop(columns=['cell_type','cell_type_grouped'], inplace=True)
embedding_features_indicator = pd.Series([isinstance(x, int) for x in X.columns])

print("Shape of embedding features indicator:", embedding_features_indicator.shape)
print("Shape of filtered features matrix:", X.shape)


Shape of embedding features indicator: (3073,)
Shape of filtered features matrix: (100000, 3073)


In [None]:
# from sklearn.model_selection import GroupShuffleSplit

# # Create group-wise split
# gss = GroupShuffleSplit(n_splits=1, test_size=0.3, random_state=42)
# train_idx, test_idx = next(gss.split(X, y, groups=X.donor_id))

# # Split the data using the indices
# X_train = X.drop(columns=['donor_id']).iloc[train_idx]
# X_test = X.drop(columns=['donor_id']).iloc[test_idx]
# y_train = y.iloc[train_idx]
# y_test = y.iloc[test_idx]


In [67]:
from sklearn.metrics import classification_report
from sklearn.ensemble import RandomForestClassifier
from lightgbm import LGBMClassifier
from sklearn.neighbors import KNeighborsClassifier

# Define the donors we want to evaluate
test_donors = [0, 5, 7]
results = []

# Perform cross-validation, holding out one donor at a time
for test_donor in test_donors:
    print(f"\n=== Cross Validation Fold: Testing on Donor {test_donor} ===")
    
    # Create initial train/test split based on donor
    train_mask = X.donor_id != test_donor
    test_indices = X[~train_mask].index
    
    # Subsample training data to get 200 samples per cell type
    train_indices = []
    for class_label in y.unique():
        # Get indices for this class from non-test donors
        class_mask = (y == class_label) & train_mask
        class_indices = X[class_mask].index
        
        # Randomly sample up to 1000 indices
        if len(class_indices) > 0:
            n_samples = min(1000, len(class_indices))
            sampled_indices = np.random.choice(class_indices, size=n_samples, replace=False)
            train_indices.extend(sampled_indices)
    
    # Create the final train/test splits
    X_train = X.drop(columns=['donor_id']).iloc[train_indices]
    X_test = X.drop(columns=['donor_id']).iloc[test_indices]
    y_train = y.iloc[train_indices]
    y_test = y.iloc[test_indices]
    
    print(f"Training set size: {len(X_train)}")
    print(f"Test set size: {len(X_test)}")
    print("\nTraining class distribution:")
    print(y_train.value_counts().sort_index())
    
    # Train and evaluate models
    models = {
        'KNN': KNeighborsClassifier(n_neighbors=10),
        'Random Forest': RandomForestClassifier(random_state=42),
        'LightGBM': LGBMClassifier(random_state=42, class_weight='balanced')
    }
    
    for name, model in models.items():
        print(f"\n{name} Results:")
        print("-" * 50)
        model.fit(X_train, y_train)
        y_pred = model.predict(X_test)

        valid_classes = sorted(set(y_test))
    # Generate report only for classes that exist in the data
        report = classification_report(y_test, y_pred, 
                        labels=valid_classes,
                        zero_division=0,
                        output_dict=True,
                        )        
        # Store results
        results.append({
            'test_donor': test_donor,
            'model': name,
            # 'accuracy': report['accuracy'],
            'macro_avg_f1': report['macro avg']['f1-score'],
            'weighted_avg_f1': report['weighted avg']['f1-score'],
            'train_size': len(X_train),
            'test_size': len(X_test)
        })
        
        print(classification_report(y_test, y_pred))

# Convert results to DataFrame for easy viewing
results_df = pd.DataFrame(results)
print("\nSummary of Results:")
print(results_df.round(3))


=== Cross Validation Fold: Testing on Donor 0 ===
Training set size: 18593
Test set size: 4851

Training class distribution:
cell_type_grouped
0      406
1     1000
2      383
4     1000
5      194
6      217
8      660
9     1000
10    1000
11     722
12     313
15    1000
17     518
19    1000
20     373
22    1000
23    1000
25    1000
28    1000
31     276
32    1000
34    1000
35    1000
36    1000
37     531
Name: count, dtype: int64

KNN Results:
--------------------------------------------------



Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.



              precision    recall  f1-score   support

           0       0.08      0.86      0.15        44
           1       0.00      0.00      0.00         0
           2       1.00      1.00      1.00        32
           4       0.60      0.37      0.46       243
           5       0.34      0.30      0.32        33
           6       0.00      0.00      0.00       375
           8       0.00      0.00      0.00         6
           9       0.88      0.77      0.82       559
          10       0.06      0.80      0.11        10
          11       0.81      0.91      0.86        97
          15       0.33      0.27      0.30       233
          17       0.91      0.33      0.48      1470
          19       0.85      0.44      0.58       146
          20       0.04      1.00      0.07         2
          22       0.04      0.78      0.07         9
          23       0.00      0.00      0.00         0
          25       0.63      0.92      0.75       167
          28       0.67    


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.



              precision    recall  f1-score   support

           0       0.11      0.77      0.20        44
           1       0.00      0.00      0.00         0
           2       0.97      1.00      0.98        32
           4       0.60      0.45      0.52       243
           5       0.22      0.06      0.10        33
           6       0.36      0.01      0.02       375
           8       0.00      0.00      0.00         6
           9       0.88      0.88      0.88       559
          10       0.09      0.70      0.15        10
          11       0.92      0.87      0.89        97
          15       0.40      0.55      0.47       233
          17       0.88      0.70      0.78      1470
          19       0.81      0.58      0.68       146
          20       0.17      1.00      0.29         2
          22       0.03      0.78      0.06         9
          23       0.00      0.00      0.00         0
          25       0.94      0.90      0.92       167
          28       0.72    


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.



              precision    recall  f1-score   support

           0       0.10      0.64      0.18        44
           1       0.00      0.00      0.00         0
           2       0.94      1.00      0.97        32
           4       0.65      0.49      0.56       243
           5       0.47      0.45      0.46        33
           6       0.80      0.02      0.04       375
           8       0.01      0.17      0.02         6
           9       0.90      0.87      0.88       559
          10       0.11      0.70      0.19        10
          11       0.93      0.82      0.87        97
          12       0.00      0.00      0.00         0
          15       0.47      0.56      0.51       233
          17       0.65      0.36      0.46      1470
          19       0.88      0.86      0.87       146
          20       0.05      1.00      0.09         2
          22       0.05      0.78      0.09         9
          23       0.00      0.00      0.00         0
          25       0.97    


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.



              precision    recall  f1-score   support

           0       0.36      0.19      0.24       258
           1       0.00      0.00      0.00         0
           2       0.45      1.00      0.62         5
           4       0.71      0.72      0.72       884
           5       0.36      0.67      0.47        81
           6       0.00      0.00      0.00         0
           8       0.00      0.00      0.00       101
           9       0.90      0.47      0.62      2783
          10       0.26      0.32      0.29       234
          11       0.80      0.96      0.88       246
          12       0.41      0.41      0.41        22
          15       0.21      0.73      0.33       811
          17       0.15      0.24      0.18        51
          19       0.86      0.73      0.79       583
          20       0.14      0.87      0.25        38
          22       0.31      0.78      0.45       560
          23       0.92      0.59      0.72      2201
          25       0.68    


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.



              precision    recall  f1-score   support

           0       0.33      0.12      0.18       258
           1       0.00      0.00      0.00         0
           2       0.44      0.80      0.57         5
           4       0.61      0.78      0.68       884
           5       0.51      0.48      0.50        81
           6       0.00      0.00      0.00         0
           8       0.00      0.00      0.00       101
           9       0.87      0.63      0.73      2783
          10       0.36      0.50      0.42       234
          11       0.90      0.97      0.94       246
          12       0.41      0.41      0.41        22
          15       0.21      0.68      0.32       811
          17       0.24      0.29      0.26        51
          19       0.81      0.81      0.81       583
          20       0.21      0.84      0.34        38
          22       0.32      0.83      0.46       560
          23       0.85      0.80      0.82      2201
          25       0.90    


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.



In [64]:
report

'              precision    recall  f1-score   support\n\n           0       0.08      0.84      0.15        44\n           1       0.00      0.00      0.00         0\n           2       1.00      1.00      1.00        32\n           4       0.62      0.35      0.44       243\n           5       0.44      0.45      0.45        33\n           6       0.00      0.00      0.00       375\n           8       0.00      0.00      0.00         6\n           9       0.89      0.78      0.83       559\n          10       0.12      0.80      0.21        10\n          11       0.84      0.90      0.87        97\n          15       0.31      0.24      0.27       233\n          17       0.93      0.34      0.50      1470\n          19       0.84      0.36      0.50       146\n          20       0.03      1.00      0.06         2\n          22       0.03      0.78      0.06         9\n          23       0.00      0.00      0.00         0\n          25       0.72      0.89      0.80       167\n       

In [94]:
with h5py.File(file_path, 'r') as f:
    cell_class = pd.Series(f['obs']['broad_cell_class']['categories'], name="cell_class_name")

In [95]:
lgbm_results_donor_7 = pd.DataFrame({
  "class": [0, 1, 2, 4, 5, 6, 8, 9, 10, 11, 12, 15, 17, 19, 20, 22, 23, 25, 28, 31, 32, 34, 35, 36, 37],
  "precision": [0.33, 0.0, 0.44, 0.61, 0.51, 0.0, 0.0, 0.87, 0.36, 0.9, 0.41, 0.21, 0.24, 0.81, 0.21, 0.32, 0.85, 0.9, 0.95, 0.11, 0.54, 0.01, 0.96, 0.02, 0.17],
  "recall": [0.12, 0.0, 0.8, 0.78, 0.48, 0.0, 0.0, 0.63, 0.5, 0.97, 0.41, 0.68, 0.29, 0.81, 0.84, 0.83, 0.8, 0.96, 0.89, 0.15, 0.13, 0.19, 0.75, 0.08, 0.03],
  "f1_score": [0.18, 0.0, 0.57, 0.68, 0.5, 0.0, 0.0, 0.73, 0.42, 0.94, 0.41, 0.32, 0.26, 0.81, 0.34, 0.46, 0.82, 0.93, 0.92, 0.12, 0.21, 0.02, 0.85, 0.04, 0.05],
  "support": [258, 0, 5, 884, 81, 0, 101, 2783, 234, 246, 22, 811, 51, 583, 38, 560, 2201, 2563, 2219, 27, 3861, 16, 4910, 97, 160]
})

In [96]:
lgbm_results_donor_7.merge(cell_class, left_index=True, right_index=True)

Unnamed: 0,class,precision,recall,f1_score,support,cell_class_name
0,0,0.33,0.12,0.18,258,b'adventitial cell'
1,1,0.0,0.0,0.0,0,b'cardiac endothelial cell'
2,2,0.44,0.8,0.57,5,b'ciliated epithelial cell'
3,4,0.61,0.78,0.68,884,b'conjunctival epithelial cell'
4,5,0.51,0.48,0.5,81,b'connective tissue cell'
5,6,0.0,0.0,0.0,0,b'contractile cell'
6,8,0.0,0.0,0.0,101,b'dendritic cell'
7,9,0.87,0.63,0.73,2783,b'duct epithelial cell'
8,10,0.36,0.5,0.42,234,b'ecto-epithelial cell'
9,11,0.9,0.97,0.94,246,b'endo-epithelial cell'
