# CellStrata — census_query Tutorial Notebook

This notebook walks through the full `census_query` workflow:

1. **Configure** a query with YAML or Python dataclasses
2. **Run** the query against the CELLxGENE Census
3. **Inspect** the returned metadata DataFrame
4. **Visualize** metadata distributions with the `_visualize` module

> **Disclaimer:** AI was used to help write this notebook. Please review carefully.

---

## 0. Setup

Make sure the CellStrata package directory is on your Python path so
imports work from inside the `Notebooks/` folder.

In [None]:
import pathlib
import sys

# Add the CellStrata package root to sys.path
pkg_root = str(pathlib.Path.cwd().parent)  # assumes notebook is in Notebooks/
if pkg_root not in sys.path:
    sys.path.insert(0, pkg_root)

print(f"Package root: {pkg_root}")

In [None]:
# Core imports
import pandas as pd
import matplotlib.pyplot as plt

# CellStrata imports
from census_query import (
    CensusTarget,
    ObsFilters,
    OutputSpec,
    QuerySpec,
    load_query_spec_yaml,
    run_query,
    # Visualization
    plot_cell_type_counts,
    plot_tissue_counts,
    plot_sex_distribution,
    plot_assay_counts,
    plot_disease_counts,
    plot_dataset_contribution,
    plot_development_stage_counts,
    plot_donors_per_dataset,
    plot_cell_type_by_tissue,
    plot_metadata_summary,
)

print("Imports OK")

---
## 1. Configure the Query

You can define the query either by loading a YAML config file or by
constructing a `QuerySpec` in Python. Both approaches are shown below.

### Option A — Load from YAML

In [None]:
# Load the default config shipped with the project
spec_yaml = load_query_spec_yaml("../config/census_query.yaml")

print(f"Census version : {spec_yaml.target.census_version}")
print(f"Organism       : {spec_yaml.target.organism}")
print(f"Output mode    : {spec_yaml.output.mode}")
print(f"Filters        : is_primary_data={spec_yaml.obs_filters.is_primary_data}, "
      f"suspension_type={spec_yaml.obs_filters.suspension_type}")
print(f"Disease IDs    : {spec_yaml.obs_filters.disease_ontology_term_ids}")
print(f"Assay labels   : {spec_yaml.obs_filters.assay_labels}")
print(f"Sex labels     : {spec_yaml.obs_filters.sex_labels}")

### Option B — Build in Python

This is useful when you want to parameterize queries or run several
variants in a loop.

In [None]:
spec = QuerySpec(
    target=CensusTarget(
        census_version="stable",
        organism="homo_sapiens",
    ),
    obs_filters=ObsFilters(
        is_primary_data=True,
        suspension_type="cell",
        disease_ontology_term_ids=["PATO:0000461"],  # healthy / normal
        assay_labels=["10x 3' v2", "10x 3' v3"],
        sex_labels=["male", "female"],
        # Uncomment to narrow the query (faster for testing):
        # tissue_general_labels=["lung"],
        # cell_type_labels=["mast cell"],
    ),
    output=OutputSpec(mode="pandas"),
    # TileDB config to avoid S3 timeouts on the cluster
    tiledb_config={
        "vfs.s3.connect_timeout_ms": 60000,
        "vfs.s3.request_timeout_ms": 600000,
        "vfs.s3.max_parallel_ops": 2,
    },
)

print("QuerySpec created")
print(f"  Organism  : {spec.target.organism}")
print(f"  Mode      : {spec.output.mode}")
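
To run several variants in a loop, one option is to copy a base spec and change a single filter per copy. The sketch below assumes the spec classes are plain dataclasses (as the introduction suggests) and uses small stand-in classes, `ObsFiltersSketch` and `QuerySpecSketch`, that only mirror two field names from this notebook; they are hypothetical and not the real CellStrata classes.

```python
from dataclasses import dataclass, field, replace
from typing import List, Optional


@dataclass
class ObsFiltersSketch:
    """Hypothetical stand-in mirroring two ObsFilters field names."""
    tissue_general_labels: Optional[List[str]] = None
    cell_type_labels: Optional[List[str]] = None


@dataclass
class QuerySpecSketch:
    """Hypothetical stand-in for QuerySpec, holding only the filters."""
    obs_filters: ObsFiltersSketch = field(default_factory=ObsFiltersSketch)


base = QuerySpecSketch(obs_filters=ObsFiltersSketch(cell_type_labels=["mast cell"]))

# One spec per tissue; replace() returns a copy, so the base spec is untouched.
variants = {
    tissue: replace(
        base,
        obs_filters=replace(base.obs_filters, tissue_general_labels=[tissue]),
    )
    for tissue in ["lung", "blood", "skin"]
}

for tissue, variant in variants.items():
    print(tissue, variant.obs_filters)
```

If the real `QuerySpec` and `ObsFilters` are dataclasses, the same `replace` pattern should carry over, with each variant passed to `run_query` in turn.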

---
## 2. Run the Query

This connects to the CELLxGENE Census (hosted remotely on S3), applies
the filters as part of the read rather than after download, and returns
only the matching cell metadata as a pandas DataFrame.

> **Note:** The first run may take a few minutes depending on filter
> breadth and network speed. A narrow filter (e.g. one tissue + one
> cell type) finishes much faster.

In [None]:
%%time
df = run_query(spec)

print(f"\nReturned {len(df):,} cells")
print(f"Columns: {list(df.columns)}")

---
## 3. Inspect the Metadata

Quick sanity checks before plotting.

In [None]:
df.head()

In [None]:
df.shape

In [None]:
print(f"Unique donors     : {df['donor_id'].nunique():,}")
print(f"Unique datasets   : {df['dataset_id'].nunique():,}")
print(f"Unique cell types : {df['cell_type'].nunique():,}")
print(f"Unique tissues    : {df['tissue_general'].nunique():,}")

In [None]:
# Top 10 cell types by count
df["cell_type"].value_counts().head(10)

In [None]:
# Cells per tissue (coarse)
df["tissue_general"].value_counts().head(10)
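
Two further checks that may be worth running before plotting are a count of missing values per column and a check that the filters took effect. The sketch below uses a small toy DataFrame in place of the query result so it runs on its own; the column names match those used in this notebook, and `requested_assays` is an illustrative name for the labels passed to `assay_labels`.

```python
import pandas as pd

# Toy stand-in for the DataFrame returned by run_query
toy = pd.DataFrame({
    "donor_id": ["d1", "d1", "d2", None],
    "assay": ["10x 3' v2", "10x 3' v3", "10x 3' v3", "10x 3' v3"],
    "tissue_general": ["lung", "lung", "blood", "lung"],
})

# Missing values per column (only columns that have any)
missing = toy.isna().sum()
print(missing[missing > 0])

# Confirm the assay filter: no values outside the requested labels
requested_assays = {"10x 3' v2", "10x 3' v3"}
unexpected = set(toy["assay"].unique()) - requested_assays
print(f"Unexpected assay values: {unexpected or 'none'}")
```

On the real result, replace `toy` with `df`.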

---
## 4. Visualize — Individual Plots

Each function returns a `matplotlib.axes.Axes` so you can customize further.

In [None]:
plot_cell_type_counts(df, top_n=20)
plt.tight_layout()
plt.show()

In [None]:
plot_tissue_counts(df, column="tissue_general", top_n=20)
plt.tight_layout()
plt.show()

In [None]:
plot_sex_distribution(df)
plt.tight_layout()
plt.show()

In [None]:
plot_assay_counts(df)
plt.tight_layout()
plt.show()

In [None]:
plot_disease_counts(df)
plt.tight_layout()
plt.show()

In [None]:
plot_dataset_contribution(df, top_n=15)
plt.tight_layout()
plt.show()

In [None]:
plot_development_stage_counts(df, top_n=10)
plt.tight_layout()
plt.show()

In [None]:
plot_donors_per_dataset(df, top_n=15)
plt.tight_layout()
plt.show()
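
Since each plot function returns its Axes, follow-up styling is ordinary matplotlib. The sketch below shows the pattern with `plot_counts_sketch`, a hypothetical stand-in written for this example (it is not one of the CellStrata functions), and toy data.

```python
import matplotlib.pyplot as plt
import pandas as pd


def plot_counts_sketch(df, column, top_n=10, ax=None):
    """Hypothetical stand-in: horizontal bar chart of the top categories."""
    if ax is None:
        _, ax = plt.subplots(figsize=(6, 4))
    counts = df[column].value_counts().head(top_n)
    # Reverse so the largest category is drawn at the top
    ax.barh(list(counts.index[::-1]), list(counts.values[::-1]))
    return ax


toy = pd.DataFrame(
    {"cell_type": ["mast cell", "T cell", "T cell", "B cell", "T cell"]}
)

ax = plot_counts_sketch(toy, "cell_type", top_n=3)

# Customize the returned Axes after the fact
ax.set_title("Top cell types (toy data)")
ax.set_xlabel("Number of cells")
ax.figure.set_size_inches(8, 3)
```

The same calls (`set_title`, `set_xlabel`, and so on) should apply to the Axes returned by the functions above.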

### Cell-type-by-tissue heatmap

Colour encodes `log10(count + 1)` so rare populations stay visible;
the number printed in each heatmap square is the raw count.

In [None]:
plot_cell_type_by_tissue(df, top_cell_types=12, top_tissues=8)
plt.tight_layout()
plt.show()
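
The colour transform described above can be reproduced by hand on a small table. This is only a sketch of the `log10(count + 1)` idea on toy data, not the code used inside the `_visualize` module.

```python
import numpy as np
import pandas as pd

# Toy metadata: one abundant and one rare cell type
toy = pd.DataFrame({
    "cell_type": ["T cell"] * 99 + ["mast cell"] * 9,
    "tissue_general": ["blood"] * 90 + ["lung"] * 9 + ["lung"] * 9,
})

# Raw counts per (cell type, tissue) pair; these would be the annotations
counts = pd.crosstab(toy["cell_type"], toy["tissue_general"])

# Colour values: the +1 keeps empty combinations at 0 instead of -inf
colour_values = np.log10(counts + 1)

print(counts)
print(colour_values.round(2))
```

On the log scale, 9 cells map to 1.0 and 90 cells to about 1.96, so the rare population still sits near the middle of the colour range rather than at the bottom.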

---
## 5. Visualize — Summary Dashboard

`plot_metadata_summary` produces a 3 x 2 grid combining six panels in
one figure. Pass `save_path` to write it to disk.

In [None]:
fig = plot_metadata_summary(df, top_n=15, save_path="metadata_summary.png")
plt.show()

---
## 6. Custom Multi-Panel Figure

Because every plot function accepts an `ax` parameter, you can compose
any layout you like.

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(20, 5))

plot_cell_type_counts(df, top_n=10, ax=axes[0])
plot_tissue_counts(df, top_n=10, ax=axes[1])
plot_sex_distribution(df, ax=axes[2])

fig.suptitle(f"Metadata overview  ({len(df):,} cells)", fontsize=14)
fig.tight_layout()
plt.show()

---
## 7. Tabular Summaries

Quick cross-tabulations that complement the plots.

In [None]:
# Cells and donors per tissue
df.groupby("tissue_general")["donor_id"].agg(
    cells="size",
    donors="nunique",
).sort_values("cells", ascending=False).head(15)

In [None]:
# Cells per dataset and assay
pd.crosstab(df["dataset_id"], df["assay"]).head(10)

---
## 8. Narrower Query — Mast Cells in Lung

A more targeted query, restricted to one tissue, one cell type, and one
assay, returns far fewer cells and should run faster.

In [None]:
spec_narrow = QuerySpec(
    target=CensusTarget(census_version="stable", organism="homo_sapiens"),
    obs_filters=ObsFilters(
        is_primary_data=True,
        suspension_type="cell",
        disease_ontology_term_ids=["PATO:0000461"],
        assay_labels=["10x 3' v3"],
        tissue_general_labels=["lung"],
        cell_type_labels=["mast cell"],
    ),
    output=OutputSpec(mode="pandas"),
    tiledb_config={
        "vfs.s3.connect_timeout_ms": 60000,
        "vfs.s3.request_timeout_ms": 600000,
    },
)

print("Narrow query spec created (lung mast cells, 10x 3' v3, healthy)")

In [None]:
%%time
df_mast = run_query(spec_narrow)
print(f"\nReturned {len(df_mast):,} lung mast cells")
print(f"Unique donors  : {df_mast['donor_id'].nunique()}")
print(f"Unique datasets: {df_mast['dataset_id'].nunique()}")

In [None]:
df_mast.head()

In [None]:
fig = plot_metadata_summary(df_mast, top_n=10, save_path="mast_cell_lung_summary.png")
plt.show()

---
## 9. Dataset List Mode

Identify which CELLxGENE datasets contain matching cells without
downloading all the metadata.

In [None]:
spec_ds = QuerySpec(
    target=CensusTarget(census_version="stable", organism="homo_sapiens"),
    obs_filters=ObsFilters(
        is_primary_data=True,
        disease_ontology_term_ids=["PATO:0000461"],
        cell_type_labels=["mast cell"],
        tissue_general_labels=["lung"],
    ),
    output=OutputSpec(mode="dataset_list"),
    tiledb_config={
        "vfs.s3.connect_timeout_ms": 60000,
        "vfs.s3.request_timeout_ms": 600000,
    },
)

dataset_ids = run_query(spec_ds)

print(f"\n{len(dataset_ids)} datasets contain lung mast cells:")
for ds_id in dataset_ids:
    print(f"  {ds_id}")

---
## 10. Next Steps

From here you can:

- Switch to `mode="anndata"` to download expression data and feed it
  into a scanpy pipeline (QC, normalization, clustering, UMAP).
- Use the `dataset_list` output to select specific datasets for
  deeper analysis.
- Save the DataFrame to CSV/Parquet for use in other tools
  (`to_parquet` needs `pyarrow` or `fastparquet` installed):

```python
df.to_csv("metadata.csv", index=False)
df.to_parquet("metadata.parquet", index=False)
```
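
To follow up on specific datasets, a metadata DataFrame can be restricted to a chosen list of IDs with `isin`. The sketch below uses toy data and made-up dataset IDs (`ds-a`, `ds-b`, `ds-c`) purely for illustration; in practice `selected_ids` would come from the `dataset_list` output.

```python
import pandas as pd

# Toy stand-in for a metadata DataFrame
toy = pd.DataFrame({
    "dataset_id": ["ds-a", "ds-b", "ds-a", "ds-c"],
    "cell_type": ["mast cell", "T cell", "mast cell", "mast cell"],
})

# Hypothetical IDs; replace with the list returned in dataset_list mode
selected_ids = ["ds-a", "ds-c"]

# Keep only rows belonging to the selected datasets
subset = toy[toy["dataset_id"].isin(selected_ids)]
print(subset["dataset_id"].value_counts())
```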