# GenePT and scGPT cell classification performance on Tabula Sapiens

This notebook downloads (if necessary) the [Tabula Sapiens data set](https://cellxgene.cziscience.com/collections/e5f58829-1a66-40b5-a624-9046778e74f5)
and uses GenePT and scGPT pretrained embeddings to embed the cells and then tests classification performance.  Tabula Sapiens is a benchmark dataset, so the models we train cannot be used for real-world applications. Rather, we are training these classifiers to benchmark our GenePT embeddings and pretrained scGPT embeddings on a large dataset with a high number of cell types.


In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pandas as pd

pd.set_option("display.max_rows", 1000)
import sys
from pathlib import Path

repo_dir = Path.cwd().parent.absolute()
sys.path.append(str(repo_dir))

In [3]:
from pathlib import Path

from src.utils import setup_data_dir

setup_data_dir()
data_dir = repo_dir / "data"

File already exists at /Users/rj/personal/GenePT-tools/data/GenePT_emebdding_v2.zip
Extracting files...
Extracting GenePT_emebdding_v2/
Skipping GenePT_emebdding_v2/NCBI_UniProt_summary_of_genes.json - already exists with same size
Skipping GenePT_emebdding_v2/GenePT_gene_embedding_ada_text.pickle - already exists with same size
Skipping GenePT_emebdding_v2/GenePT_gene_protein_embedding_model_3_text.pickle. - already exists with same size
Skipping GenePT_emebdding_v2/NCBI_summary_of_genes.json - already exists with same size
Extraction complete!
Setup finished!


In [4]:
embed_scgpt_pdf = pd.read_parquet("../data/tabula_sapiens_100k_scgpt_embedding.parquet")
embed_scgpt_pdf.columns

Index(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
       ...
       '505', '506', '507', '508', '509', '510', '511', 'cell_type',
       'broad_cell_class', 'donor_id'],
      dtype='object', length=515)

In [6]:
embed_genept_pdf = pd.read_parquet(
    "../data/tabula_sapiens_100k_genept_embedding.parquet"
)
embed_genept_pdf.columns

Index(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
       ...
       '3065', '3066', '3067', '3068', '3069', '3070', '3071', 'cell_type',
       'broad_cell_class', 'donor_id'],
      dtype='object', length=3075)

In [7]:
import matplotlib.pyplot as plt
import numpy as np
import plotly.express as px
import umap
from sklearn.decomposition import PCA


def umap_embed(embed_pdf, n_samples=2000):
    # Then apply UMAP to the PCA results
    reducer = umap.UMAP(random_state=42)
    np.random.seed(42)
    random_indices = np.random.choice(embed_pdf.shape[0], size=n_samples, replace=False)
    umap_sample_pdf = embed_pdf.iloc[random_indices].drop(
        columns=["cell_type", "donor_id", "broad_cell_class"]
    )
    umap_embeddings = reducer.fit_transform(umap_sample_pdf)

    umap_df = pd.DataFrame(
        umap_embeddings, columns=["UMAP1", "UMAP2"], index=umap_sample_pdf.index
    ).merge(embed_pdf, left_index=True, right_index=True)
    return umap_df


umap_embeddings_scgpt = umap_embed(embed_scgpt_pdf)
umap_embeddings_genept = umap_embed(embed_genept_pdf)

  warn(
  warn(


In [8]:
for embed_pdf, name in zip(
    [umap_embeddings_scgpt, umap_embeddings_genept], ["scGPT", "GenePT"]
):
    # Create the plot
    fig = px.scatter(
        embed_pdf,
        x="UMAP1",
        y="UMAP2",
        color="cell_type",
        opacity=0.7,
        title=name,
    )

    # Update layout to make the data area square
    fig.update_layout(
        title={"y": 0.95, "x": 0.5, "xanchor": "center", "yanchor": "top"},
        width=1400,  # Wider to accommodate legend
        height=800,
    )

    # Make the plot area square by adjusting margins
    fig.update_layout(
        margin=dict(r=200),  # Add right margin for legend
        xaxis=dict(domain=[0, 0.8]),  # Restrict plot area width to make it square
        yaxis=dict(scaleanchor="x", scaleratio=1),
    )

    fig.show()

In [10]:
import plotly.express as px

px.histogram(embed_scgpt_pdf.broad_cell_class.sort_values())

In [11]:
# Create a cross-tabulation of donor_id and cell_type
heatmap_data = pd.crosstab(embed_scgpt_pdf.donor_id, embed_scgpt_pdf.broad_cell_class)

# Create heatmap using plotly
import numpy as np
import plotly.express as px

# Apply log10 transform to the data (adding 1 to avoid log(0))
log_data = np.log10(heatmap_data.values + 1)

# Create regular heatmap with log-transformed data
fig = px.imshow(
    log_data,
    labels=dict(x="Cell Type", y="Donor ID", color="Count"),
    x=heatmap_data.columns,
    y=heatmap_data.index,
    color_continuous_scale="Viridis",
    title="Cell Type Distribution Across Donors (Log Scale)",
    aspect="auto",
)

# Update hover template to show both log and linear values
fig.data[0].customdata = heatmap_data.values
fig.data[0].hovertemplate = (
    "Cell Type: %{x}<br>Donor ID: %{y}<br>Count: %{customdata:.0f}<br>Log10 Count: %{z:.2f}<extra></extra>"
)

# Create tick values for the colorbar (in log space)
tick_values = np.linspace(log_data.min(), log_data.max(), 6)
# Convert tick values back to linear space for labels
tick_labels = [f"{int(10**x - 1)}" for x in tick_values]

# Update layout and colorbar
fig.update_layout(
    xaxis_title="Cell Type",
    yaxis_title="Donor ID",
    height=700,  # Adjusted height (increase as needed)
    coloraxis=dict(
        colorbar=dict(title="Count", tickvals=tick_values, ticktext=tick_labels)
    ),
)

fig.show()

In [12]:
def create_cell_type_groups(df, min_samples=600):
    """Create grouped cell types, combining rare types into 'other'

    Args:
        df: DataFrame containing 'broad_cell_class' column
        min_samples: Minimum number of samples required to keep a category

    Returns:
        Series with grouped cell types
    """
    # Get value counts and identify small categories
    category_counts = pd.Series(df.broad_cell_class.value_counts())
    small_categories = category_counts[category_counts < min_samples].index

    # Get existing categories and add 'other'
    existing_categories = df.broad_cell_class.cat.categories
    new_categories = pd.Index(existing_categories).append(pd.Index(["other"]))

    # Create new column with expanded categories
    cell_type_grouped = df.broad_cell_class.astype(str).astype("category")
    cell_type_grouped = cell_type_grouped.cat.set_categories(new_categories)

    # Assign the 'other' category
    cell_type_grouped.loc[df.broad_cell_class.isin(small_categories)] = "other"

    return cell_type_grouped


# Apply the function to both dataframes
embed_scgpt_pdf["cell_type_grouped"] = create_cell_type_groups(embed_scgpt_pdf)
embed_genept_pdf["cell_type_grouped"] = create_cell_type_groups(embed_genept_pdf)

In [19]:
embed_genept_pdf.shape

(100000, 3076)

In [13]:

embed_genept_pdf.cell_type_grouped.value_counts()


cell_type_grouped
t cell                             14053
stromal cell                       13013
myeloid leukocyte                   8565
lymphocyte of b lineage             8499
contractile cell                    7916
fibroblast                          6995
endothelial cell                    6019
stem cell                           5937
granulocyte                         5797
intestinal epithelial cell          5764
transitional epithelial cell        5384
other                               3180
innate lymphoid cell                2507
glandular epithelial cell           1988
epithelial cell                     1810
cardiac endothelial cell            1088
epithelial cell of lung              819
endo-epithelial cell                 666
conjunctival epithelial cell           0
ciliated epithelial cell               0
connective tissue cell                 0
meso-epithelial cell                   0
stratified epithelial cell             0
dendritic cell                         

In [35]:
train_test_counts =pd.merge(
    embed_genept_pdf.cell_type_grouped.value_counts(),
    embed_genept_pdf[embed_genept_pdf.donor_id == "TSP1"].cell_type_grouped.value_counts().rename("TSP1"),
    how="outer",
    left_index=True,
    right_index=True,
).merge(
    embed_genept_pdf[embed_genept_pdf.donor_id != "TSP1"].cell_type_grouped.value_counts().rename("not_TSP1"),
    how="outer",
    left_index=True,
    right_index=True,
).merge(
    embed_genept_pdf[embed_genept_pdf.donor_id == "TSP2"].cell_type_grouped.value_counts().rename("TSP2"),
    how="outer",
    left_index=True,
    right_index=True,
).merge(
    embed_genept_pdf[embed_genept_pdf.donor_id != "TSP2"].cell_type_grouped.value_counts().rename("not_TSP2"),
    how="outer",
    left_index=True,
    right_index=True,
).merge(
    embed_genept_pdf[embed_genept_pdf.donor_id == "TSP14"].cell_type_grouped.value_counts().rename("TSP14"),
    how="outer",
    left_index=True,
    right_index=True,
).merge(
    embed_genept_pdf[embed_genept_pdf.donor_id != "TSP14"].cell_type_grouped.value_counts().rename("not_TSP14"),
    how="outer",
    left_index=True,
    right_index=True,
)

train_test_counts

Unnamed: 0_level_0,count,TSP1,not_TSP1,TSP2,not_TSP2,TSP14,not_TSP14
cell_type_grouped,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
adventitial cell,0,0,0,0,0,0,0
cardiac endothelial cell,1088,0,1088,0,1088,57,1031
ciliated epithelial cell,0,0,0,0,0,0,0
conjunctival epithelial cell,0,0,0,0,0,0,0
connective tissue cell,0,0,0,0,0,0,0
contractile cell,7916,243,7673,884,7032,633,7283
dendritic cell,0,0,0,0,0,0,0
duct epithelial cell,0,0,0,0,0,0,0
ecto-epithelial cell,0,0,0,0,0,0,0
endo-epithelial cell,666,6,660,101,565,24,642


In [None]:
# X = pd.DataFrame(ref_embed_adata.obsm["X_scGPT"])
# y = ref_embed_adata.obs["broad_cell_class"]
# X["donor_id"] = ref_embed_adata.obs.donor_id.cat.codes.to_numpy()

# # print("Shape of embedding features indicator:", embedding_features_indicator.shape)
# print("Shape of filtered features matrix:", X.shape)

In [None]:
# y == "t cell"

In [85]:
# def get_mask_for_label_excluding_donor(y, label, test_donor):
#     return (y == label) & (X.donor_id != test_donor)


# y = embed_scgpt_pdf.broad_cell_class
# test_donor = "TSP14"
# label = "endo-epithelial cell"
# mask = get_mask_for_label_excluding_donor(y, label, test_donor)
# y.index[mask]
# # embed_scgpt_pdf[mask]
# sample_count = mask.sum()
# n_samples = min(1000, sample_count)

# sampled_indices = pd.Index(
#     np.random.choice(pd.Series(y[mask].index), size=n_samples, replace=False)
# )

In [None]:
# from sklearn.model_selection import GroupShuffleSplit

# # Create group-wise split
# gss = GroupShuffleSplit(n_splits=1, test_size=0.3, random_state=42)
# train_idx, test_idx = next(gss.split(X, y, groups=X.donor_id))

# # Split the data using the indices
# X_train = X.drop(columns=['donor_id']).iloc[train_idx]
# X_test = X.drop(columns=['donor_id']).iloc[test_idx]
# y_train = y.iloc[train_idx]
# y_test = y.iloc[test_idx]

In [None]:
# (y == "t cell").index

In [95]:
# (X.donor_id != test_donor).index

Index(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
       ...
       '99990', '99991', '99992', '99993', '99994', '99995', '99996', '99997',
       '99998', '99999'],
      dtype='object', length=100000)

In [103]:
# del combined_embedding_pdf

In [107]:
embed_scgpt_pdf.index = embed_genept_pdf.index
combined_embedding_pdf = embed_scgpt_pdf.drop(
    columns=["donor_id", "cell_type", "broad_cell_class", "cell_type_grouped"]
).merge(embed_genept_pdf, left_index=True, right_index=True)
combined_embedding_pdf.shape

(100000, 3588)

In [109]:
from lightgbm import LGBMClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report
from sklearn.neighbors import KNeighborsClassifier

# Define the donors we want to evaluate
test_donors = ["TSP1", "TSP2", "TSP14"]
results = []


def get_mask_for_label_excluding_donor(y, label, test_donor):
    return (y == label) & (X.donor_id != test_donor)


for embed_pdf, embed_name in zip(
    [combined_embedding_pdf, embed_scgpt_pdf, embed_genept_pdf],
    ["combined", "scGPT", "GenePT"],
):
    X = embed_pdf.drop(columns=["cell_type", "broad_cell_class", "cell_type_grouped"])
    y = embed_pdf.cell_type_grouped

    # Perform cross-validation, holding out one donor at a time
    for test_donor in test_donors:
        print(f"\n=== Cross Validation Fold: Testing on Donor {test_donor} ===")

        # Create initial train/test split based on donor
        train_mask = X.donor_id != test_donor
        test_indices = X[~train_mask].index

        # Subsample training data to get 200 samples per cell type
        train_indices = []
        clipped_train_indices = []
        for class_label in y.unique():

            mask = get_mask_for_label_excluding_donor(y, class_label, test_donor)
            sample_count = mask.sum()
            print(f"{class_label}: {sample_count}/{(y == class_label).sum()}")

            # Randomly sample up to 1000 indices
            if len(mask) > 0:
                train_indices.extend(y[mask].index)

                n_samples = min(1000, sample_count)
                sampled_indices = pd.Index(
                    np.random.choice(
                        pd.Series(y[mask].index), size=n_samples, replace=False
                    )
                )
                clipped_train_indices.extend(sampled_indices)
            else:
                print(f"warning: class '{class_label}' has no samples!")

        # Create the final train/test splits
        X_train = X.drop(columns=["donor_id"]).iloc[train_indices]
        X_train_clipped = X.drop(columns=["donor_id"]).iloc[clipped_train_indices]
        X_test = X.drop(columns=["donor_id"]).iloc[test_indices]
        y_train = y.iloc[train_indices]
        y_train_clipped = y.iloc[clipped_train_indices]
        y_test = y.iloc[test_indices]

        print(y_train.value_counts().sort_index())
        print(y_test.value_counts().sort_index())

        print(f"Training set size: {len(X_train)}")
        print(f"Clipped Training set size: {len(X_train_clipped)}")
        print(f"Test set size: {len(X_test)}")
        print("\nTraining class distribution:")
        print(y_train.value_counts().sort_index())

        # Train and evaluate models
        models = {
            "KNN": KNeighborsClassifier(n_neighbors=10),
            "Random Forest": RandomForestClassifier(random_state=42),
            "LightGBM": LGBMClassifier(random_state=42, class_weight="balanced"),
        }

        for name, model in models.items():
            print(f"\n{name} Results:")
            print("-" * 50)
            if name == "Random Forest":
                model.fit(X_train, y_train)
            else:
                model.fit(X_train_clipped, y_train_clipped)
            y_pred = model.predict(X_test)

            valid_classes = sorted(set(y_test))
            # Generate report only for classes that exist in the data
            report = classification_report(
                y_test,
                y_pred,
                # labels=valid_classes,
                zero_division=0,
                output_dict=True,
            )
            # Store results
            results.append(
                {
                    "embed_name": embed_name,
                    "test_donor": test_donor,
                    "model": name,
                    # 'accuracy': report['accuracy'],
                    "macro_avg_f1": report["macro avg"]["f1-score"],
                    "weighted_avg_f1": report["weighted avg"]["f1-score"],
                    "train_size": len(X_train),
                    "test_size": len(X_test),
                    "report": report,
                }
            )

            print(classification_report(y_test, y_pred))

# Convert results to DataFrame for easy viewing
results_df = pd.DataFrame(results)
print("\nSummary of Results:")
print(results_df.round(3))


=== Cross Validation Fold: Testing on Donor TSP1 ===
t cell: 13636/14053
lymphocyte of b lineage: 8332/8499
innate lymphoid cell: 2498/2507
endothelial cell: 5460/6019
other: 2693/3180
contractile cell: 7673/7916
granulocyte: 5651/5797
myeloid leukocyte: 8195/8565
cardiac endothelial cell: 1088/1088
glandular epithelial cell: 518/1988
epithelial cell: 1800/1810
epithelial cell of lung: 722/819
stem cell: 5500/5937
stromal cell: 12969/13013
fibroblast: 6762/6995
endo-epithelial cell: 660/666
intestinal epithelial cell: 5764/5764
transitional epithelial cell: 5228/5384
cell_type_grouped
adventitial cell                       0
cardiac endothelial cell            1088
ciliated epithelial cell               0
conjunctival epithelial cell           0
connective tissue cell                 0
contractile cell                    7673
dendritic cell                         0
duct epithelial cell                   0
ecto-epithelial cell                   0
endo-epithelial cell                 6


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.



                              precision    recall  f1-score   support

    cardiac endothelial cell       0.00      0.00      0.00         0
            contractile cell       0.48      0.53      0.51       243
        endo-epithelial cell       0.04      0.67      0.08         6
            endothelial cell       0.91      0.80      0.85       559
             epithelial cell       0.07      0.70      0.13        10
     epithelial cell of lung       0.91      0.90      0.90        97
                  fibroblast       0.31      0.26      0.28       233
   glandular epithelial cell       0.49      0.14      0.22      1470
                 granulocyte       0.84      0.78      0.81       146
        innate lymphoid cell       0.03      0.67      0.06         9
  intestinal epithelial cell       0.00      0.00      0.00         0
     lymphocyte of b lineage       0.63      0.89      0.73       167
           myeloid leukocyte       0.76      0.80      0.78       370
                   


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.



                              precision    recall  f1-score   support

    cardiac endothelial cell       0.00      0.00      0.00         0
            contractile cell       0.40      0.58      0.47       243
        endo-epithelial cell       1.00      0.17      0.29         6
            endothelial cell       0.91      0.93      0.92       559
             epithelial cell       0.29      0.80      0.42        10
     epithelial cell of lung       0.96      0.81      0.88        97
                  fibroblast       0.30      0.76      0.43       233
   glandular epithelial cell       0.81      0.03      0.06      1470
                 granulocyte       0.88      0.75      0.81       146
        innate lymphoid cell       0.27      0.67      0.39         9
  intestinal epithelial cell       0.00      0.00      0.00         0
     lymphocyte of b lineage       0.71      0.90      0.79       167
           myeloid leukocyte       0.73      0.86      0.79       370
                   


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.



                              precision    recall  f1-score   support

    cardiac endothelial cell       0.00      0.00      0.00         0
            contractile cell       0.63      0.55      0.59       243
        endo-epithelial cell       0.07      0.50      0.12         6
            endothelial cell       0.92      0.87      0.89       559
             epithelial cell       0.17      0.70      0.28        10
     epithelial cell of lung       0.98      0.91      0.94        97
                  fibroblast       0.44      0.70      0.54       233
   glandular epithelial cell       0.74      0.37      0.49      1470
                 granulocyte       0.89      0.83      0.86       146
        innate lymphoid cell       0.06      0.89      0.11         9
  intestinal epithelial cell       0.00      0.00      0.00         0
     lymphocyte of b lineage       0.98      0.90      0.94       167
           myeloid leukocyte       0.78      0.83      0.81       370
                   


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.



                              precision    recall  f1-score   support

    cardiac endothelial cell       0.00      0.00      0.00         0
            contractile cell       0.83      0.82      0.82       884
        endo-epithelial cell       0.02      0.03      0.02       101
            endothelial cell       0.92      0.56      0.70      2783
             epithelial cell       0.74      0.59      0.66       234
     epithelial cell of lung       0.91      0.99      0.95       246
                  fibroblast       0.18      0.75      0.29       811
   glandular epithelial cell       0.23      0.75      0.35        51
                 granulocyte       0.81      0.95      0.87       583
        innate lymphoid cell       0.31      0.82      0.45       560
  intestinal epithelial cell       0.85      0.88      0.87      2201
     lymphocyte of b lineage       0.99      0.97      0.98      2563
           myeloid leukocyte       0.92      0.89      0.90      2219
                   


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.



                              precision    recall  f1-score   support

    cardiac endothelial cell       0.00      0.00      0.00         0
            contractile cell       0.70      0.88      0.78       884
        endo-epithelial cell       0.00      0.00      0.00       101
            endothelial cell       0.93      0.83      0.87      2783
             epithelial cell       0.84      0.48      0.61       234
     epithelial cell of lung       0.94      0.98      0.96       246
                  fibroblast       0.24      0.96      0.38       811
   glandular epithelial cell       0.81      0.51      0.63        51
                 granulocyte       0.86      0.94      0.90       583
        innate lymphoid cell       0.71      0.58      0.64       560
  intestinal epithelial cell       0.76      0.95      0.84      2201
     lymphocyte of b lineage       0.99      0.98      0.99      2563
           myeloid leukocyte       0.92      0.96      0.94      2219
                   


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.



                              precision    recall  f1-score   support

    cardiac endothelial cell       0.00      0.00      0.00         0
            contractile cell       0.62      0.88      0.73       884
        endo-epithelial cell       0.01      0.01      0.01       101
            endothelial cell       0.93      0.75      0.83      2783
             epithelial cell       0.88      0.57      0.69       234
     epithelial cell of lung       0.95      0.98      0.96       246
                  fibroblast       0.22      0.84      0.34       811
   glandular epithelial cell       0.59      0.71      0.64        51
                 granulocyte       0.87      0.95      0.91       583
        innate lymphoid cell       0.39      0.91      0.54       560
  intestinal epithelial cell       0.88      0.89      0.88      2201
     lymphocyte of b lineage       0.99      0.98      0.98      2563
           myeloid leukocyte       0.97      0.90      0.94      2219
                   


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.



                              precision    recall  f1-score   support

    cardiac endothelial cell       0.00      0.00      0.00         0
            contractile cell       0.50      0.55      0.52       243
        endo-epithelial cell       0.03      0.67      0.06         6
            endothelial cell       0.91      0.78      0.84       559
             epithelial cell       0.04      0.70      0.08        10
     epithelial cell of lung       0.93      0.89      0.91        97
                  fibroblast       0.30      0.34      0.32       233
   glandular epithelial cell       0.64      0.25      0.36      1470
                 granulocyte       0.88      0.73      0.80       146
        innate lymphoid cell       0.03      0.67      0.06         9
  intestinal epithelial cell       0.00      0.00      0.00         0
     lymphocyte of b lineage       0.70      0.88      0.78       167
           myeloid leukocyte       0.77      0.79      0.78       370
                   


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.



                              precision    recall  f1-score   support

    cardiac endothelial cell       0.00      0.00      0.00         0
            contractile cell       0.13      0.58      0.21       243
        endo-epithelial cell       0.50      0.50      0.50         6
            endothelial cell       0.91      0.92      0.92       559
             epithelial cell       0.26      0.70      0.38        10
     epithelial cell of lung       0.99      0.85      0.91        97
                  fibroblast       0.28      0.61      0.39       233
   glandular epithelial cell       0.78      0.02      0.04      1470
                 granulocyte       0.86      0.78      0.82       146
        innate lymphoid cell       0.16      0.67      0.26         9
  intestinal epithelial cell       0.00      0.00      0.00         0
     lymphocyte of b lineage       0.81      0.90      0.85       167
           myeloid leukocyte       0.76      0.85      0.80       370
                   


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.



                              precision    recall  f1-score   support

    cardiac endothelial cell       0.00      0.00      0.00         0
            contractile cell       0.51      0.56      0.53       243
        endo-epithelial cell       0.02      0.67      0.04         6
            endothelial cell       0.90      0.83      0.86       559
             epithelial cell       0.27      0.70      0.39        10
     epithelial cell of lung       0.99      0.88      0.93        97
                  fibroblast       0.37      0.61      0.46       233
   glandular epithelial cell       0.64      0.06      0.11      1470
                 granulocyte       0.91      0.79      0.84       146
        innate lymphoid cell       0.07      0.89      0.13         9
  intestinal epithelial cell       0.00      0.00      0.00         0
     lymphocyte of b lineage       0.81      0.89      0.85       167
           myeloid leukocyte       0.78      0.80      0.79       370
                   


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.



                              precision    recall  f1-score   support

    cardiac endothelial cell       0.00      0.00      0.00         0
            contractile cell       0.89      0.81      0.85       884
        endo-epithelial cell       0.02      0.03      0.02       101
            endothelial cell       0.92      0.58      0.71      2783
             epithelial cell       0.73      0.57      0.64       234
     epithelial cell of lung       0.92      0.99      0.95       246
                  fibroblast       0.17      0.76      0.28       811
   glandular epithelial cell       0.24      0.71      0.35        51
                 granulocyte       0.85      0.93      0.89       583
        innate lymphoid cell       0.30      0.84      0.44       560
  intestinal epithelial cell       0.85      0.89      0.87      2201
     lymphocyte of b lineage       0.99      0.98      0.99      2563
           myeloid leukocyte       0.89      0.90      0.90      2219
                   


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.



                              precision    recall  f1-score   support

    cardiac endothelial cell       0.00      0.00      0.00         0
            contractile cell       0.68      0.90      0.78       884
        endo-epithelial cell       0.00      0.00      0.00       101
            endothelial cell       0.93      0.84      0.88      2783
             epithelial cell       0.93      0.53      0.67       234
     epithelial cell of lung       0.94      0.98      0.96       246
                  fibroblast       0.21      0.92      0.34       811
   glandular epithelial cell       0.54      0.73      0.62        51
                 granulocyte       0.88      0.95      0.91       583
        innate lymphoid cell       0.69      0.68      0.68       560
  intestinal epithelial cell       0.82      0.95      0.88      2201
     lymphocyte of b lineage       0.99      0.98      0.99      2563
           myeloid leukocyte       0.91      0.96      0.93      2219
                   


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.



                              precision    recall  f1-score   support

    cardiac endothelial cell       0.00      0.00      0.00         0
            contractile cell       0.88      0.85      0.87       884
        endo-epithelial cell       0.02      0.02      0.02       101
            endothelial cell       0.93      0.74      0.82      2783
             epithelial cell       0.84      0.52      0.64       234
     epithelial cell of lung       0.94      0.99      0.96       246
                  fibroblast       0.18      0.80      0.29       811
   glandular epithelial cell       0.43      0.75      0.54        51
                 granulocyte       0.88      0.93      0.90       583
        innate lymphoid cell       0.39      0.88      0.54       560
  intestinal epithelial cell       0.91      0.89      0.90      2201
     lymphocyte of b lineage       0.99      0.98      0.99      2563
           myeloid leukocyte       0.95      0.91      0.93      2219
                   


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.



                              precision    recall  f1-score   support

    cardiac endothelial cell       0.00      0.00      0.00         0
            contractile cell       0.58      0.37      0.45       243
        endo-epithelial cell       0.00      0.00      0.00         6
            endothelial cell       0.90      0.75      0.82       559
             epithelial cell       0.09      0.80      0.16        10
     epithelial cell of lung       0.72      0.90      0.80        97
                  fibroblast       0.32      0.32      0.32       233
   glandular epithelial cell       0.91      0.42      0.58      1470
                 granulocyte       0.87      0.50      0.63       146
        innate lymphoid cell       0.04      0.78      0.07         9
  intestinal epithelial cell       0.00      0.00      0.00         0
     lymphocyte of b lineage       0.70      0.89      0.78       167
           myeloid leukocyte       0.71      0.80      0.75       370
                   


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.



                              precision    recall  f1-score   support

    cardiac endothelial cell       0.00      0.00      0.00         0
            contractile cell       0.60      0.51      0.55       243
        endo-epithelial cell       0.00      0.00      0.00         6
            endothelial cell       0.89      0.93      0.91       559
             epithelial cell       0.13      0.40      0.20        10
     epithelial cell of lung       0.91      0.72      0.80        97
                  fibroblast       0.33      0.86      0.47       233
   glandular epithelial cell       1.00      0.09      0.16      1470
                 granulocyte       0.86      0.64      0.74       146
        innate lymphoid cell       0.19      0.56      0.28         9
  intestinal epithelial cell       0.00      0.00      0.00         0
     lymphocyte of b lineage       0.89      0.89      0.89       167
           myeloid leukocyte       0.71      0.84      0.77       370
                   


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.



                              precision    recall  f1-score   support

    cardiac endothelial cell       0.00      0.00      0.00         0
            contractile cell       0.64      0.47      0.54       243
        endo-epithelial cell       0.00      0.00      0.00         6
            endothelial cell       0.90      0.87      0.88       559
             epithelial cell       0.10      0.70      0.18        10
     epithelial cell of lung       0.93      0.81      0.87        97
                  fibroblast       0.46      0.73      0.56       233
   glandular epithelial cell       0.82      0.82      0.82      1470
                 granulocyte       0.88      0.78      0.83       146
        innate lymphoid cell       0.05      0.78      0.10         9
  intestinal epithelial cell       0.00      0.00      0.00         0
     lymphocyte of b lineage       0.97      0.90      0.93       167
           myeloid leukocyte       0.79      0.81      0.80       370
                   


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.



                              precision    recall  f1-score   support

    cardiac endothelial cell       0.00      0.00      0.00         0
            contractile cell       0.68      0.69      0.69       884
        endo-epithelial cell       0.00      0.00      0.00       101
            endothelial cell       0.90      0.48      0.62      2783
             epithelial cell       0.23      0.30      0.26       234
     epithelial cell of lung       0.85      0.97      0.91       246
                  fibroblast       0.21      0.74      0.33       811
   glandular epithelial cell       0.16      0.31      0.21        51
                 granulocyte       0.89      0.74      0.81       583
        innate lymphoid cell       0.34      0.79      0.48       560
  intestinal epithelial cell       0.92      0.60      0.73      2201
     lymphocyte of b lineage       0.66      0.95      0.78      2563
           myeloid leukocyte       0.92      0.86      0.89      2219
                   


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.



                              precision    recall  f1-score   support

    cardiac endothelial cell       0.00      0.00      0.00         0
            contractile cell       0.64      0.84      0.73       884
        endo-epithelial cell       0.00      0.00      0.00       101
            endothelial cell       0.91      0.79      0.85      2783
             epithelial cell       0.48      0.31      0.38       234
     epithelial cell of lung       0.96      0.93      0.95       246
                  fibroblast       0.24      0.93      0.38       811
   glandular epithelial cell       0.32      0.25      0.28        51
                 granulocyte       0.85      0.86      0.86       583
        innate lymphoid cell       0.62      0.30      0.40       560
  intestinal epithelial cell       0.83      0.88      0.85      2201
     lymphocyte of b lineage       0.97      0.96      0.96      2563
           myeloid leukocyte       0.90      0.94      0.92      2219
                   


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.


Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.



                              precision    recall  f1-score   support

    cardiac endothelial cell       0.00      0.00      0.00         0
            contractile cell       0.62      0.85      0.72       884
        endo-epithelial cell       0.00      0.00      0.00       101
            endothelial cell       0.92      0.69      0.79      2783
             epithelial cell       0.56      0.53      0.55       234
     epithelial cell of lung       0.96      0.96      0.96       246
                  fibroblast       0.23      0.79      0.35       811
   glandular epithelial cell       0.52      0.59      0.55        51
                 granulocyte       0.84      0.90      0.87       583
        innate lymphoid cell       0.35      0.86      0.49       560
  intestinal epithelial cell       0.84      0.85      0.85      2201
     lymphocyte of b lineage       0.98      0.96      0.97      2563
           myeloid leukocyte       0.96      0.91      0.93      2219
                   

In [117]:
results_df.report.iloc[0]["cardiac endothelial cell"]

{'precision': 0.0, 'recall': 0.0, 'f1-score': 0.0, 'support': 0.0}

In [175]:
results_df.to_parquet(data_dir / "algorithm_comparison.parquet")

# Reload results
So that we can skip the training when re-doing analysis

In [23]:
results_df = pd.read_parquet(data_dir / "algorithm_comparison.parquet")

Validate that the support counts match the test counts

In [38]:
support_counts =[
    {key:value["support"] for key,value in results_df.report.iloc[i].items() if key not in ("accuracy", "macro avg", "weighted avg") }
    for i in [1,4,7]
]
train_test_counts.merge(
    pd.DataFrame(support_counts).T,
    how="outer",
    left_index=True,
    right_index=True,
)


Unnamed: 0,count,TSP1,not_TSP1,TSP2,not_TSP2,TSP14,not_TSP14,0,1,2
adventitial cell,0,0,0,0,0,0,0,,,
cardiac endothelial cell,1088,0,1088,0,1088,57,1031,0.0,0.0,57.0
ciliated epithelial cell,0,0,0,0,0,0,0,,,
conjunctival epithelial cell,0,0,0,0,0,0,0,,,
connective tissue cell,0,0,0,0,0,0,0,,,
contractile cell,7916,243,7673,884,7032,633,7283,243.0,884.0,633.0
dendritic cell,0,0,0,0,0,0,0,,,
duct epithelial cell,0,0,0,0,0,0,0,,,
ecto-epithelial cell,0,0,0,0,0,0,0,,,
endo-epithelial cell,666,6,660,101,565,24,642,6.0,101.0,24.0


In [46]:
",".join(str(x) for x in train_test_counts[train_test_counts["count"] != 0].not_TSP14.to_list())

'1031,7283,642,4723,1337,343,5695,1684,1957,1508,4520,5137,5276,5685,13010,9461,5059,2414'

1088,7673,660,5460,1800,722,6762,518,5651,2498,5764,8332,8195,5500,12969,13636,5228,2693+

In [None]:
%load_ext autoreload
%autoreload 2

In [24]:
results_df[results_df.model != "KNN"][results_df.test_donor == "TSP2"]


Boolean Series key will be reindexed to match DataFrame index.



Unnamed: 0,embed_name,test_donor,model,macro_avg_f1,weighted_avg_f1,train_size,test_size,report
4,combined,TSP2,Random Forest,0.59779,0.739929,77289,22711,"{'accuracy': 0.7563295319448725, 'cardiac endo..."
5,combined,TSP2,LightGBM,0.610351,0.756143,77289,22711,"{'accuracy': 0.7396856149002686, 'cardiac endo..."
13,scGPT,TSP2,Random Forest,0.623975,0.758196,77289,22711,"{'accuracy': 0.7674695081678482, 'cardiac endo..."
14,scGPT,TSP2,LightGBM,0.606601,0.759511,77289,22711,"{'accuracy': 0.737175817885606, 'cardiac endot..."
22,GenePT,TSP2,Random Forest,0.496542,0.702569,77289,22711,"{'accuracy': 0.7184183875654969, 'cardiac endo..."
23,GenePT,TSP2,LightGBM,0.53155,0.723403,77289,22711,"{'accuracy': 0.7047245828012857, 'cardiac endo..."


In [165]:
# Extract cell types and metrics from the nested report dictionary
results_list = []
for _, row in results_df[results_df.model != "KNN"].iterrows():
    report = row["report"]
    # Skip the aggregate metrics
    cell_types = [
        k for k in report.keys() if k not in ["accuracy", "macro avg", "weighted avg"]
    ]
    for cell_type in cell_types:
        metrics = report[cell_type]
        results_list.append(
            {
                "cell_type": f"{row['test_donor']} {cell_type} ({metrics['support']:.0f})",
                "embed_name": row["embed_name"],
                "model": row["model"],
                "test_donor": row["test_donor"],
                "precision": metrics["precision"],
                "recall": metrics["recall"],
                "f1-score": metrics["f1-score"],
                "support": metrics["support"],
            }
        )

# Convert to DataFrame and reshape for heatmap
results_flat = pd.DataFrame(results_list)
# First create separate pivots for each metric
metrics = ["precision", "recall", "f1-score"]
pivot_dfs = []

for metric in metrics:
    pivot = results_flat.pivot_table(
        columns="cell_type", index=["embed_name", "model"], values=metric
    )
    # Add metric name to index
    pivot.index = [f"{metric} {idx[0]} {idx[1]}" for idx in pivot.index]
    pivot_dfs.append(pivot)

# Concatenate all metric pivots
results_pivot = pd.concat(pivot_dfs)

In [176]:
results_pivot.to_csv(data_dir / "algorithm_comparison_pivot.csv")

In [171]:
# Create grid of heatmaps
metrics = ["precision", "recall", "f1-score"]
test_donors = results_flat["test_donor"].unique()

# Create subplots grid
fig = make_subplots(
    rows=3,
    cols=3,
    shared_xaxes=True,
    shared_yaxes=True,
    subplot_titles=[f"Donor {donor}" for donor in test_donors],
    vertical_spacing=0.05,
    horizontal_spacing=0.02,
)

# Add each heatmap
for i, metric in enumerate(metrics, 1):
    for j, donor in enumerate(test_donors, 1):
        # Filter data for this metric and donor
        donor_data = results_flat[results_flat["test_donor"] == donor].pivot_table(
            columns="cell_type", index=["embed_name", "model"], values=metric
        )

        # Create y-axis labels (only for leftmost column)
        yaxis_labels = (
            [f"{idx[0]} {idx[1]}" for idx in donor_data.index] if j == 1 else None
        )

        # Add heatmap trace
        heatmap = px.imshow(
            donor_data, color_continuous_scale="RdYlBu", y=yaxis_labels
        ).data[0]

        # Remove black borders and disable hover
        heatmap.update(
            showscale=False, xgap=0, ygap=0, hoverongaps=False, hovertemplate=None
        )

        fig.add_trace(heatmap, row=i, col=j)

        # Update yaxis properties for leftmost column
        if j == 1:
            fig.update_yaxes(
                ticktext=yaxis_labels,
                tickvals=list(range(len(yaxis_labels))),
                row=i,
                col=j,
            )

# Update layout
fig.update_layout(
    width=1500,
    height=1000,
    showlegend=False,
    # Disable zoom/pan
    xaxis=dict(fixedrange=True),
    yaxis=dict(fixedrange=True),
    dragmode=False,
)

# Apply fixed range to all subplots
for i in range(1, 4):
    for j in range(1, 4):
        fig.update_xaxes(fixedrange=True, row=i, col=j)
        fig.update_yaxes(fixedrange=True, row=i, col=j)

# Hide all x-axis labels except bottom row
for i in range(1, 3):  # First two rows
    for j in range(1, 4):  # All columns
        fig.update_xaxes(showticklabels=False, row=i, col=j)

# Show x-axis labels only for bottom row, rotated 90 degrees
for j in range(1, 4):
    fig.update_xaxes(tickangle=90, row=3, col=j)

# Hide y-axis labels except left column
for j in range(2, 4):  # Second and third columns
    for i in range(1, 4):  # All rows
        fig.update_yaxes(showticklabels=False, row=i, col=j)

# Add y-axis titles for left column
for i, metric in enumerate(metrics, 1):
    fig.update_yaxes(title_text=metric.title(), row=i, col=1)

# Update subplot titles position
fig.update_annotations(y=1.05)

fig.show()