In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pandas as pd

pd.set_option("display.max_rows", 1000)
import sys
from pathlib import Path

# Use the parent of the current working directory as the repository root and
# append it to sys.path so that the `src` package imported in the next cell
# can be resolved.
# NOTE(review): this relies on the kernel's working directory being a direct
# subdirectory of the repository — confirm how the notebook is launched.
repo_dir = Path.cwd().parent.absolute()
sys.path.append(str(repo_dir))

In [3]:
from pathlib import Path

from src.utils import setup_data_dir

# Run the project's data-setup helper, then point data_dir at <repo>/data.
# NOTE(review): judging by the log this cell prints, setup_data_dir() fetches
# and extracts the GenePT embedding archive — confirm against src.utils.
setup_data_dir()
data_dir = repo_dir / "data"

File already exists at /root/GenePT-tools/data/GenePT_emebdding_v2.zip
Extracting files...
Extracting GenePT_emebdding_v2/
Skipping GenePT_emebdding_v2/NCBI_UniProt_summary_of_genes.json - already exists with same size
Skipping GenePT_emebdding_v2/GenePT_gene_embedding_ada_text.pickle - already exists with same size
Skipping GenePT_emebdding_v2/GenePT_gene_protein_embedding_model_3_text.pickle. - already exists with same size
Skipping GenePT_emebdding_v2/NCBI_summary_of_genes.json - already exists with same size
Extraction complete!
Setup finished!


In [4]:
import requests

# Source .h5ad file on the CELLxGENE dataset server.
dataset = "https://datasets.cellxgene.cziscience.com/10df7690-6d10-4029-a47e-0f071bb2df83.h5ad"
# dataset_id = "10df7690-6d10-4029-a47e-0f071bb2df83"

file_path = data_dir / "1m_cells.h5ad"  # adjust this path as needed

# Download once; later runs reuse the local copy.
if not file_path.exists():
    # Stream into a temporary sibling file and rename it only after the whole
    # body has been written. An interrupted or failed download therefore never
    # leaves a truncated file behind that the exists() check above would
    # accept as complete on the next run.
    partial_path = file_path.with_name(file_path.name + ".part")
    try:
        # timeout=(connect, read) in seconds so a stalled connection cannot
        # hang the kernel forever; the context manager releases the connection.
        with requests.get(dataset, stream=True, timeout=(10, 300)) as response:
            # Fail loudly on 4xx/5xx instead of saving an error page as .h5ad.
            response.raise_for_status()
            with open(partial_path, "wb") as file:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:  # filter out keep-alive chunks
                        file.write(chunk)
        partial_path.replace(file_path)
    except BaseException:
        # Also covers a kernel interrupt (KeyboardInterrupt): drop the partial
        # file, then let the original error propagate.
        partial_path.unlink(missing_ok=True)
        raise

In [5]:
from src.utils import load_subset_anndata

# Read a slice of the downloaded .h5ad file, requesting only the three obs
# annotation columns named below.
adata_filtered = load_subset_anndata(
    file_path,
    start_row=0,
    n_rows=100000,
    obs_columns=["cell_type", "broad_cell_class", "donor_id"],
)

# Sanity report: dimensions, available metadata columns and the fraction of
# explicitly stored entries in the expression matrix.
n_cells, n_genes = adata_filtered.shape
stored_fraction = adata_filtered.X.nnz / (n_cells * n_genes)

print("AnnData shape:", adata_filtered.shape)
# All feature (var) metadata columns
print("Feature metadata columns:", adata_filtered.var.columns)
# Only the obs columns selected above
print("Selected Observation metadata columns:", adata_filtered.obs.columns)
print("Matrix density:", stored_fraction)



AnnData shape: (1136218, 61759)
Feature metadata columns: Index(['ensembl_id', 'ensg', 'ercc', 'feature_biotype', 'feature_is_filtered',
       'feature_length', 'feature_name', 'feature_reference', 'feature_type',
       'genome', 'mean', 'mean_counts', 'mt', 'n_cells_by_counts',
       'pct_dropout_by_counts', 'std', 'total_counts'],
      dtype='object')
Selected Observation metadata columns: Index(['cell_type', 'broad_cell_class', 'donor_id'], dtype='object')
Matrix density: 0.05041241182799931


In [None]:
# Directory holding the pretrained scGPT checkpoint.
# NOTE(review): absolute, machine-specific path — confirm or adjust before
# running on another machine.
model_dir = Path("/root/scGPT/save/scGPT_human")
# Name handed to scGPT as the gene column; "feature_name" is one of the var
# metadata columns listed in the output above.
# NOTE(review): presumably scGPT matches these names against its vocabulary — confirm.
gene_col = "feature_name"
import scgpt as scg

# Run scGPT over the loaded cells in batches of 64. Later cells read the
# per-cell embeddings from ref_embed_adata.obsm["X_scGPT"].
ref_embed_adata = scg.tasks.embed_data(
    adata_filtered,
    model_dir,
    gene_col=gene_col,
    batch_size=64,
)

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Persist the embedded AnnData to disk.
# The var index name is cleared first because writing otherwise fails on a
# clash between the index name and the `feature_name` column.
# NOTE(review): exact cause not established here — presumably the writer
# rejects an index name that duplicates a column name; confirm.
ref_embed_adata.var.index.name = None
# NOTE(review): the active file name says "all" while the load cell requests
# n_rows=100000; the commented-out line uses the "100k" name. Confirm which
# subset this file is meant to hold.
# ref_embed_adata.write("../data/tabula_sapiens_100k_scgpt_embedding.h5ad")
ref_embed_adata.write("../data/tabula_sapiens_all_scgpt_embedding.h5ad")

In [8]:
# Display the per-cell annotation table of the embedded AnnData.
ref_embed_adata.obs

Unnamed: 0,cell_type,broad_cell_class,donor_id
0,"naive thymus-derived CD4-positive, alpha-beta ...",t cell,TSP2
1,B cell,lymphocyte of b lineage,TSP2
2,B cell,lymphocyte of b lineage,TSP2
3,B cell,lymphocyte of b lineage,TSP2
4,"CD8-positive, alpha-beta T cell",t cell,TSP2
...,...,...,...
99995,endothelial cell of artery,endothelial cell,TSP2
99996,mesenchymal stem cell,stem cell,TSP2
99997,pericyte,contractile cell,TSP2
99998,skeletal muscle satellite stem cell,stem cell,TSP2


In [9]:
# Give the per-cell annotations a positional 0..n-1 index so they line up,
# row for row, with a DataFrame built from obsm["X_scGPT"] (which gets a
# default RangeIndex).
# The length comes from the data instead of a literal 100000, so the cell no
# longer raises a length-mismatch error when a different number of cells is
# loaded.
# NOTE: no copy is made — ref_embed_obs_pdf refers to ref_embed_adata.obs, so
# this assignment re-indexes the AnnData's obs table as well (same as before).
ref_embed_obs_pdf = ref_embed_adata.obs
ref_embed_obs_pdf.index = pd.RangeIndex(start=0, stop=len(ref_embed_obs_pdf), step=1)

In [11]:
# Wrap the embedding matrix in a DataFrame (integer column labels, default row
# index) and attach the annotation columns by matching row index.
embed_scgpt_pdf = pd.merge(
    pd.DataFrame(ref_embed_adata.obsm["X_scGPT"]),
    ref_embed_obs_pdf,
    left_index=True,
    right_index=True,
)
# Turn every column label into a string so the embedding dimensions and the
# annotation columns share one label type.
embed_scgpt_pdf.columns = embed_scgpt_pdf.columns.map(str)
embed_scgpt_pdf

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,505,506,507,508,509,510,511,cell_type,broad_cell_class,donor_id
0,-0.033486,-0.054066,-0.029279,-0.051283,0.027736,-0.019041,0.014046,-0.020791,0.015448,-0.010465,...,0.029240,0.006003,-0.007707,0.023624,-0.052029,-0.019846,-0.001054,"naive thymus-derived CD4-positive, alpha-beta ...",t cell,TSP2
1,-0.022117,-0.022753,-0.043033,-0.036327,0.055331,-0.022720,0.004311,-0.037500,-0.001853,-0.009056,...,0.056050,0.013014,-0.009980,-0.009603,-0.058539,-0.016580,-0.013383,B cell,lymphocyte of b lineage,TSP2
2,-0.007590,-0.007243,-0.041140,-0.037873,0.032792,-0.006948,0.001643,-0.031062,-0.011355,-0.019418,...,0.042358,0.004911,-0.003878,-0.015848,-0.049975,-0.007029,0.002340,B cell,lymphocyte of b lineage,TSP2
3,-0.028240,-0.034682,-0.030902,-0.033507,0.057636,-0.023642,0.013100,-0.034982,0.005553,-0.009135,...,0.054909,0.010684,-0.009662,-0.007196,-0.059451,-0.018297,-0.003425,B cell,lymphocyte of b lineage,TSP2
4,-0.021991,-0.028865,-0.010812,-0.050530,0.050317,-0.013198,0.014578,-0.033548,0.013169,-0.000067,...,0.041926,0.016854,-0.005093,-0.007156,-0.049383,-0.029127,0.002126,"CD8-positive, alpha-beta T cell",t cell,TSP2
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
99995,0.006164,0.004585,-0.044006,-0.011603,-0.024295,-0.053256,-0.023546,-0.020517,0.020568,-0.010009,...,-0.046977,-0.033839,-0.051678,0.018070,-0.014599,-0.033662,-0.000064,endothelial cell of artery,endothelial cell,TSP2
99996,0.010340,0.018859,-0.064489,-0.046560,-0.017704,-0.047513,-0.005364,-0.021554,0.020964,0.020232,...,0.002843,-0.044125,-0.025023,0.022193,-0.025148,-0.016971,0.010183,mesenchymal stem cell,stem cell,TSP2
99997,0.004445,0.045190,-0.011946,-0.045088,-0.001150,-0.036768,-0.034440,-0.028356,0.010642,-0.016961,...,-0.007213,-0.028585,-0.043290,0.028288,-0.021030,0.016056,0.006466,pericyte,contractile cell,TSP2
99998,-0.009389,0.025690,-0.055871,-0.051706,0.020641,-0.044475,0.002695,-0.027326,0.029309,0.001589,...,0.004108,-0.043317,-0.014888,0.014512,-0.040493,-0.026234,0.002629,skeletal muscle satellite stem cell,stem cell,TSP2


In [12]:
# Cache the combined embedding + annotation table as Parquet.
# NOTE(review): the path is relative to the kernel's working directory, unlike
# data_dir used earlier — confirm both point at the same folder.
embed_scgpt_pdf.to_parquet("../data/tabula_sapiens_100k_scgpt_embedding.parquet")

In [13]:
# Reload the cached table (replacing the in-memory frame) and list its columns
# to check that the embedding dimensions and annotations round-tripped.
embed_scgpt_pdf = pd.read_parquet("../data/tabula_sapiens_100k_scgpt_embedding.parquet")
embed_scgpt_pdf.columns

Index(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
       ...
       '505', '506', '507', '508', '509', '510', '511', 'cell_type',
       'broad_cell_class', 'donor_id'],
      dtype='object', length=515)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import plotly.express as px
import umap
from sklearn.decomposition import PCA

# Project a random subsample of the scGPT cell embeddings to 2-D with UMAP.
# UMAP is fitted directly on the embeddings; no PCA step is applied.
# Both the subsampling and UMAP use the same fixed seed, so re-running the
# cell reproduces the same figure.
umap_seed = 42
max_umap_cells = 2000

reducer = umap.UMAP(random_state=umap_seed)
n_embedded_cells = ref_embed_adata.obsm["X_scGPT"].shape[0]
# Sampling without replacement raises when asked for more items than exist,
# so cap the sample size at the number of available cells.
random_indices = np.random.default_rng(umap_seed).choice(
    n_embedded_cells, size=min(max_umap_cells, n_embedded_cells), replace=False
)
# Rows of the embedding matrix for the sampled cells, in random_indices order.
umap_sample_pdf = ref_embed_adata.obsm["X_scGPT"][random_indices]
umap_embeddings = reducer.fit_transform(umap_sample_pdf)

In [None]:
# UMAP coordinates of the sampled cells; rows follow the order of
# random_indices and carry a fresh 0..n-1 index.
umap_df = pd.DataFrame(umap_embeddings, columns=["UMAP1", "UMAP2"])

# Colour each point by its broad cell class.
# - The labels come from ref_embed_adata, the same object the embeddings were
#   sampled from, so row positions are guaranteed to correspond.
# - The class names are used as strings (not integer category codes) so the
#   plot gets one discrete colour and legend entry per class, not a
#   continuous colour scale over arbitrary code numbers.
umap_df["cell_type"] = (
    ref_embed_adata.obs.broad_cell_class.iloc[random_indices].astype(str).to_numpy()
)

# Create the plot
fig = px.scatter(
    umap_df,
    x="UMAP1",
    y="UMAP2",
    color="cell_type",
    opacity=0.7,
    title="UMAP Visualization of Gene Expression Embeddings",
)

# Centre the title at the top of the figure
fig.update_layout(title={"y": 0.95, "x": 0.5, "xanchor": "center", "yanchor": "top"})

fig.show()

In [None]:
import plotly.express as px

# Number of cells per broad cell class; sorting the values first groups equal
# classes together along the axis.
px.histogram(ref_embed_adata.obs.broad_cell_class.sort_values())

In [None]:
# Heatmap of how many cells of each broad cell class every donor contributes.
# Colours are on a log10(count + 1) scale; hover text and colour-bar labels
# show the untransformed counts.

# Cross-tabulation: rows = donor_id, columns = broad_cell_class, values = cell counts
heatmap_data = pd.crosstab(
    ref_embed_adata.obs.donor_id, ref_embed_adata.obs.broad_cell_class
)

# Create heatmap using plotly
import numpy as np
import plotly.express as px

# Apply log10 transform to the data (adding 1 to avoid log(0))
log_data = np.log10(heatmap_data.values + 1)

# Create regular heatmap with log-transformed data
fig = px.imshow(
    log_data,
    labels=dict(x="Cell Type", y="Donor ID", color="Count"),
    x=heatmap_data.columns,
    y=heatmap_data.index,
    color_continuous_scale="Viridis",
    title="Cell Type Distribution Across Donors (Log Scale)",
    aspect="auto",
)

# Update hover template to show both log and linear values
fig.data[0].customdata = heatmap_data.values
fig.data[0].hovertemplate = (
    "Cell Type: %{x}<br>Donor ID: %{y}<br>Count: %{customdata:.0f}<br>Log10 Count: %{z:.2f}<extra></extra>"
)

# Create tick values for the colorbar (in log space)
tick_values = np.linspace(log_data.min(), log_data.max(), 6)
# Convert tick values back to linear space for labels.
# Round rather than truncate: 10**log10(n + 1) can land a hair below n + 1 in
# floating point, and int() alone would then print n - 1 instead of n.
tick_labels = [f"{int(round(10**x - 1))}" for x in tick_values]

# Update layout and colorbar
fig.update_layout(
    xaxis_title="Cell Type",
    yaxis_title="Donor ID",
    height=700,  # Adjusted height (increase as needed)
    coloraxis=dict(
        colorbar=dict(title="Count", tickvals=tick_values, ticktext=tick_labels)
    ),
)

fig.show()

In [None]:
# Get value counts and identify categories with < 200 samples
category_counts = pd.Series(ref_embed_adata.obs.broad_cell_class.value_counts())
small_categories = category_counts[category_counts < 200].index

# Create new column with remapped categories
# NOTE(review): `cell_embeddings_pdf` is not defined in any cell visible in
# this notebook, so the lines below raise NameError on a fresh kernel unless
# it is created elsewhere — confirm.
# NOTE(review): small_categories holds broad_cell_class labels but is matched
# against the `cell_type` column, and `max(...) + 1` only makes sense for
# numeric class codes — this looks carried over from a version with
# integer-coded classes; confirm the intent.
# NOTE(review): the variable `cell_type_grouped` assigned on the next line is
# never read; the .loc assignment writes to a column of the same name.
cell_type_grouped = cell_embeddings_pdf.cell_type
cell_embeddings_pdf.loc[
    cell_embeddings_pdf.cell_type.isin(small_categories), "cell_type_grouped"
] = (max(cell_embeddings_pdf.cell_type) + 1)

In [None]:
# Target: the broad cell class of every cell.
y = ref_embed_adata.obs["broad_cell_class"]
# Features: one column per embedding dimension, plus the donor's integer
# category code in a "donor_id" column (used for donor-wise splitting).
X = pd.DataFrame(ref_embed_adata.obsm["X_scGPT"]).assign(
    donor_id=ref_embed_adata.obs.donor_id.cat.codes.to_numpy()
)

# print("Shape of embedding features indicator:", embedding_features_indicator.shape)
print("Shape of filtered features matrix:", X.shape)

In [None]:
# Scratch check: element-wise boolean mask of the cells labelled "t cell".
y == "t cell"

In [None]:
# from sklearn.model_selection import GroupShuffleSplit

# # Create group-wise split
# gss = GroupShuffleSplit(n_splits=1, test_size=0.3, random_state=42)
# train_idx, test_idx = next(gss.split(X, y, groups=X.donor_id))

# # Split the data using the indices
# X_train = X.drop(columns=['donor_id']).iloc[train_idx]
# X_test = X.drop(columns=['donor_id']).iloc[test_idx]
# y_train = y.iloc[train_idx]
# y_test = y.iloc[test_idx]

In [None]:
# Scratch check: the index of the comparison result. This is the index of the
# whole boolean Series (one entry per cell), not only the positions that are True.
(y == "t cell").index

In [None]:
# Scratch check: index of the donor comparison (again one entry per row of X).
# NOTE(review): `test_donor` is only assigned by the loop in a later cell, so
# this cell fails with NameError when the notebook is run top to bottom.
(X.donor_id != test_donor).index

In [None]:
from lightgbm import LGBMClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report
from sklearn.neighbors import KNeighborsClassifier

# Leave-one-donor-out evaluation of three classifiers on the scGPT embeddings.
#
# For every held-out donor, all of that donor's cells form the test set. The
# training set is drawn from the remaining donors with at most
# `max_train_per_class` cells per cell class, so that abundant classes do not
# dominate training.

# Donors to hold out, given as integer category codes (X["donor_id"] stores
# category codes, not donor names).
test_donors = [0, 1, 13]
# Upper bound on the number of training cells sampled per cell class per fold.
max_train_per_class = 1000
# Seeded generator so the subsampling, and therefore the scores, are reproducible.
sampling_rng = np.random.default_rng(42)
results = []

# Embedding columns only; the donor code is a grouping key, not a feature.
embedding_features = X.drop(columns=["donor_id"])

# Perform cross-validation, holding out one donor at a time
for test_donor in test_donors:
    print(f"\n=== Cross Validation Fold: Testing on Donor {test_donor} ===")

    # Create initial train/test split based on donor. Masks are converted to
    # plain NumPy arrays / row positions so that combining them and indexing
    # with .iloc never depends on how X and y happen to be indexed.
    train_mask = X.donor_id != test_donor
    is_train_row = train_mask.to_numpy()
    test_indices = np.flatnonzero(~is_train_row)
    if len(test_indices) == 0:
        # Nothing to evaluate on; fitting/predicting on an empty fold would fail.
        print(f"warning: donor {test_donor} has no cells, skipping fold")
        continue

    # Subsample training data: up to max_train_per_class cells per cell class
    train_indices = []
    for class_label in y.unique():
        # Rows that belong to this class AND to a non-held-out donor.
        # The class condition has to use the comparison's *values*: its .index
        # lists every row, which would make each class select the complete
        # training set and defeat the per-class sampling.
        class_mask = is_train_row & (y == class_label).to_numpy()

        # Row positions of this class among the training donors
        class_indices = np.flatnonzero(class_mask)
        print(f"y == class_label positives = {(y == class_label).sum()}")
        print(f"train_mask positives = {train_mask.sum()}")
        print(f"class_mask positives = {class_mask.sum()}")
        # Randomly sample up to max_train_per_class positions without replacement
        if len(class_indices) > 0:
            n_samples = min(max_train_per_class, len(class_indices))
            sampled_indices = sampling_rng.choice(
                class_indices, size=n_samples, replace=False
            )
            train_indices.extend(sampled_indices)
        else:
            print(f"warning: class '{class_label}' has no samples!")

    # Create the final train/test splits (positional indexing throughout)
    X_train = embedding_features.iloc[train_indices]
    X_test = embedding_features.iloc[test_indices]
    y_train = y.iloc[train_indices]
    y_test = y.iloc[test_indices]

    print(f"Training set size: {len(X_train)}")
    print(f"Test set size: {len(X_test)}")
    print("\nTraining class distribution:")
    print(y_train.value_counts().sort_index())

    # Fresh, untrained models for every fold
    models = {
        "KNN": KNeighborsClassifier(n_neighbors=10),
        "Random Forest": RandomForestClassifier(random_state=42),
        "LightGBM": LGBMClassifier(random_state=42, class_weight="balanced"),
    }

    for name, model in models.items():
        print(f"\n{name} Results:")
        print("-" * 50)
        model.fit(X_train, y_train)
        y_pred = model.predict(X_test)

        # Score only the classes that actually occur in this donor's cells
        valid_classes = sorted(set(y_test))
        report = classification_report(
            y_test,
            y_pred,
            labels=valid_classes,
            zero_division=0,
            output_dict=True,
        )
        # Store results
        results.append(
            {
                "test_donor": test_donor,
                "model": name,
                # 'accuracy': report['accuracy'],
                "macro_avg_f1": report["macro avg"]["f1-score"],
                "weighted_avg_f1": report["weighted avg"]["f1-score"],
                "train_size": len(X_train),
                "test_size": len(X_test),
            }
        )

        # Print the report with the same label set and zero-division handling
        # as the stored one, so the text and the summary numbers agree.
        print(
            classification_report(
                y_test, y_pred, labels=valid_classes, zero_division=0
            )
        )

# Convert results to DataFrame for easy viewing
results_df = pd.DataFrame(results)
print("\nSummary of Results:")
print(results_df.round(3))

In [None]:
# Scratch check: number of rows outside the held-out donor, using the
# train_mask left over from the last fold of the loop above.
train_mask.sum()

In [None]:
# Scratch check: classification-report dict left over from the last model
# evaluated in the last fold of the loop above.
report

In [None]:
# Read the category labels of obs/broad_cell_class straight from the .h5ad
# file into a Series named "cell_class_name" with a default 0..n-1 index.
# NOTE(review): h5py is not imported in any cell visible in this notebook —
# confirm it is imported before this cell runs.
# NOTE(review): presumably position i holds the label for category code i;
# verify against the file layout.
with h5py.File(file_path, "r") as f:
    cell_class = pd.Series(
        f["obs"]["broad_cell_class"]["categories"], name="cell_class_name"
    )

In [None]:
# Hard-coded per-class LightGBM metrics (one record per class code), kept as a
# table for inspection. Each record is
# (class, precision, recall, f1_score, support).
lgbm_results_donor_7 = pd.DataFrame(
    [
        (0, 0.33, 0.12, 0.18, 258),
        (1, 0.0, 0.0, 0.0, 0),
        (2, 0.44, 0.8, 0.57, 5),
        (4, 0.61, 0.78, 0.68, 884),
        (5, 0.51, 0.48, 0.5, 81),
        (6, 0.0, 0.0, 0.0, 0),
        (8, 0.0, 0.0, 0.0, 101),
        (9, 0.87, 0.63, 0.73, 2783),
        (10, 0.36, 0.5, 0.42, 234),
        (11, 0.9, 0.97, 0.94, 246),
        (12, 0.41, 0.41, 0.41, 22),
        (15, 0.21, 0.68, 0.32, 811),
        (17, 0.24, 0.29, 0.26, 51),
        (19, 0.81, 0.81, 0.81, 583),
        (20, 0.21, 0.84, 0.34, 38),
        (22, 0.32, 0.83, 0.46, 560),
        (23, 0.85, 0.8, 0.82, 2201),
        (25, 0.9, 0.96, 0.93, 2563),
        (28, 0.95, 0.89, 0.92, 2219),
        (31, 0.11, 0.15, 0.12, 27),
        (32, 0.54, 0.13, 0.21, 3861),
        (34, 0.01, 0.19, 0.02, 16),
        (35, 0.96, 0.75, 0.85, 4910),
        (36, 0.02, 0.08, 0.04, 97),
        (37, 0.17, 0.03, 0.05, 160),
    ],
    columns=["class", "precision", "recall", "f1_score", "support"],
)

In [None]:
# Attach the readable class name to every row of the per-class metrics.
# The lookup is keyed on the value in the "class" column, not on the row
# position: the class codes in the table skip values (0, 1, 2, 4, ...), so a
# join on the row index pairs most rows with the wrong name.
# Assumes cell_class is indexed by category code (its default 0..n-1 index) —
# TODO confirm against the .h5ad category order.
# A code without a matching entry yields NaN; the metrics row is kept.
lgbm_results_donor_7.assign(
    cell_class_name=lgbm_results_donor_7["class"].map(cell_class)
)