# CellStrata — census_query Tutorial Notebook

This notebook walks through the full `census_query` workflow using
**mast cells in healthy human lung** as the running example:

1. **Configure** a query with YAML or Python dataclasses
2. **Run** the query against the CELLxGENE Census
3. **Inspect** the returned metadata DataFrame
4. **Visualize** metadata distributions with the `_visualize` module
5. **Stream** results to Arrow / Parquet for large-scale filtering
6. **Down-sample** to a single dataset for rapid prototyping

> **Disclaimer:** AI was used to help write this notebook. Please review carefully.

---

## 0. Setup

Make sure the CellStrata package directory is on your Python path so
imports work from inside the `Notebooks/` folder.

In [None]:
import sys, pathlib

# Add the CellStrata package root to sys.path
pkg_root = str(pathlib.Path.cwd().parent)  # assumes notebook is in Notebooks/
if pkg_root not in sys.path:
    sys.path.insert(0, pkg_root)

print(f"Package root: {pkg_root}")

In [None]:
# Core imports
import pandas as pd
import matplotlib.pyplot as plt

# CellStrata imports
from census_query import (
    CensusTarget,
    ObsFilters,
    OutputSpec,
    QuerySpec,
    load_query_spec_yaml,
    run_query,
    # I/O — Arrow / Parquet streaming
    stream_obs_tables,
    write_parquet_stream,
    _resolve_outpath,
    # Visualization
    plot_cell_type_counts,
    plot_tissue_counts,
    plot_sex_distribution,
    plot_assay_counts,
    plot_disease_counts,
    plot_dataset_contribution,
    plot_development_stage_counts,
    plot_donors_per_dataset,
    plot_cell_type_by_tissue,
    plot_metadata_summary,
)

print("Imports OK")

---
## 1. Configure the Query

You can define the query either by loading a YAML config file or by
constructing a `QuerySpec` in Python. Both approaches are shown below.

### Option A — Load from YAML

In [None]:
# Load the default config shipped with the project
spec_yaml = load_query_spec_yaml("../config/census_query.yaml")

print(f"Census version : {spec_yaml.target.census_version}")
print(f"Organism       : {spec_yaml.target.organism}")
print(f"Output mode    : {spec_yaml.output.mode}")
print(f"Filters        : is_primary_data={spec_yaml.obs_filters.is_primary_data}, "
      f"suspension_type={spec_yaml.obs_filters.suspension_type}")
print(f"Disease IDs    : {spec_yaml.obs_filters.disease_ontology_term_ids}")
print(f"Assay labels   : {spec_yaml.obs_filters.assay_labels}")
print(f"Sex labels     : {spec_yaml.obs_filters.sex_labels}")

### Option B — Build in Python

This is useful when you want to parameterize queries or run several
variants in a loop.

In [None]:
spec = QuerySpec(
    target=CensusTarget(
        census_version="stable",
        organism="homo_sapiens",
    ),
    obs_filters=ObsFilters(
        is_primary_data=True,
        suspension_type="cell",
        disease_ontology_term_ids=["PATO:0000461"],  # healthy / normal
        tissue_general_labels=["lung"],
        cell_type_labels=["mast cell"],
    ),
    output=OutputSpec(mode="pandas"),
    # TileDB config to avoid S3 timeouts on the cluster
    tiledb_config={
        "vfs.s3.connect_timeout_ms": 60000,
        "vfs.s3.request_timeout_ms": 600000,
        "vfs.s3.max_parallel_ops": 2,
    },
)

print("QuerySpec created — mast cells in healthy lung")
print(f"  Organism  : {spec.target.organism}")
print(f"  Mode      : {spec.output.mode}")
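print(f"  Filters   : disease={spec.obs_filters.disease_ontology_term_ids}, "
      f"tissue={spec.obs_filters.tissue_general_labels}, "
      f"cell_type={spec.obs_filters.cell_type_labels}")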
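To run several variants in a loop, one approach is to keep a base filter set and derive each variant with `dataclasses.replace`, which copies a dataclass instance while overriding selected fields. The sketch below uses a hypothetical `ObsFiltersSketch` stand-in (so it runs without CellStrata installed) and illustrative cell-type labels. It assumes the real `ObsFilters` is a plain dataclass, as the configuration step suggests.

```python
from __future__ import annotations

from dataclasses import dataclass, replace
from typing import Optional


@dataclass(frozen=True)
class ObsFiltersSketch:
    """Hypothetical stand-in for census_query.ObsFilters (fields taken from this notebook)."""
    is_primary_data: bool = True
    suspension_type: Optional[str] = None
    disease_ontology_term_ids: tuple[str, ...] = ()
    tissue_general_labels: tuple[str, ...] = ()
    cell_type_labels: tuple[str, ...] = ()


base = ObsFiltersSketch(
    suspension_type="cell",
    disease_ontology_term_ids=("PATO:0000461",),
    tissue_general_labels=("lung",),
    cell_type_labels=("mast cell",),
)

# Derive one variant per cell type; every other field is copied from `base`.
CELL_TYPES = ["mast cell", "alveolar macrophage", "type II pneumocyte"]  # illustrative labels
variants = {ct: replace(base, cell_type_labels=(ct,)) for ct in CELL_TYPES}

for name, filters in variants.items():
    print(f"{name:<20} tissue={filters.tissue_general_labels} cell_type={filters.cell_type_labels}")
```

With the real classes, the analogous call would be `replace(spec, obs_filters=replace(spec.obs_filters, cell_type_labels=[ct]))`, passing each resulting spec to `run_query`; that also assumes `QuerySpec` is a dataclass.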

---
## 2. Run the Query

This opens the CELLxGENE Census (stored on S3), passes the filters to the
underlying TileDB-SOMA query so they are evaluated while the data are read,
and returns only the matching cell metadata as a pandas DataFrame.

> **Note:** The first run may take a few minutes depending on filter
> breadth and network speed. The mast-cells-in-lung query is quite
> narrow, so it should finish quickly.
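It can help to picture what the filters become. The sketch below shows one plausible translation of the `ObsFilters` fields into a TileDB-SOMA `value_filter` expression, the kind of string the Census Python API accepts (for example as `obs_value_filter` in `cellxgene_census.get_anndata`). The helper `build_value_filter` and its field-to-column mapping (`tissue_general_labels` to `tissue_general`, and so on) are assumptions for illustration, not CellStrata's actual code. Values are quoted with `repr()`, assuming the filter grammar accepts Python-style string literals.

```python
import ast
from typing import Optional, Sequence


def build_value_filter(
    is_primary_data: Optional[bool] = None,
    suspension_type: Optional[str] = None,
    disease_ontology_term_ids: Sequence[str] = (),
    tissue_general_labels: Sequence[str] = (),
    cell_type_labels: Sequence[str] = (),
    extra_value_filter: Optional[str] = None,
) -> str:
    """Compose a SOMA-style obs value_filter string (hypothetical helper)."""
    clauses = []
    if is_primary_data is not None:
        clauses.append(f"is_primary_data == {is_primary_data}")
    if suspension_type is not None:
        clauses.append(f"suspension_type == {suspension_type!r}")
    # Map each list-valued filter onto its obs column; `in [...]` matches any listed value.
    for column, values in (
        ("disease_ontology_term_id", disease_ontology_term_ids),
        ("tissue_general", tissue_general_labels),
        ("cell_type", cell_type_labels),
    ):
        if values:
            clauses.append(f"{column} in {list(values)!r}")
    if extra_value_filter:
        # Parenthesize free-form text so an `or` inside it cannot escape the `and` chain.
        clauses.append(f"({extra_value_filter})")
    return " and ".join(clauses)


expr = build_value_filter(
    is_primary_data=True,
    suspension_type="cell",
    disease_ontology_term_ids=["PATO:0000461"],
    tissue_general_labels=["lung"],
    cell_type_labels=["mast cell"],
)
print(expr)
ast.parse(expr, mode="eval")  # raises SyntaxError if the string is malformed
```

Parsing the result with `ast.parse(..., mode="eval")` is a cheap way to catch a malformed expression before any network I/O.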

In [None]:
%%time
df = run_query(spec)

print(f"\nReturned {len(df):,} mast cells from healthy lung")
print(f"Columns: {list(df.columns)}")

---
## 3. Inspect the Metadata

Quick sanity checks before plotting.

In [None]:
df.head()

In [None]:
df.shape

In [None]:
print(f"Unique donors   : {df['donor_id'].nunique():,}")
print(f"Unique datasets : {df['dataset_id'].nunique():,}")
print(f"Unique tissues  : {df['tissue_general'].nunique():,}")
print(f"Unique assays   : {df['assay'].nunique():,}")

In [None]:
# Assay breakdown for lung mast cells
df["assay"].value_counts()

In [None]:
# Cells per dataset
df["dataset_id"].value_counts()
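Beyond counting, a quick check that every returned row actually satisfies the requested filters catches label typos or a filter that was silently ignored. `check_filters` and the `EXPECTED` mapping below are hypothetical helpers; the toy DataFrame stands in for `df`, and the column names assume Census `obs` naming.

```python
from __future__ import annotations

import pandas as pd

# Hypothetical expectations mirroring the Section 1 filters (Census obs column names assumed).
EXPECTED = {
    "cell_type": {"mast cell"},
    "tissue_general": {"lung"},
    "disease_ontology_term_id": {"PATO:0000461"},
    "is_primary_data": {True},
    "suspension_type": {"cell"},
}


def check_filters(df: pd.DataFrame, expected: dict[str, set]) -> list[str]:
    """Return one message per filter column that is missing or holds unexpected values."""
    problems = []
    for column, allowed in expected.items():
        if column not in df.columns:
            problems.append(f"missing column: {column}")
            continue
        # unique() works for object and categorical columns alike.
        unexpected = set(df[column].unique()) - allowed
        if unexpected:
            problems.append(f"{column}: unexpected values {sorted(map(str, unexpected))}")
    return problems


toy_obs = pd.DataFrame({
    "cell_type": ["mast cell"] * 3,
    "tissue_general": ["lung"] * 3,
    "disease_ontology_term_id": ["PATO:0000461"] * 3,
    "is_primary_data": [True] * 3,
    "suspension_type": ["cell"] * 3,
})
print(check_filters(toy_obs, EXPECTED) or "all filters respected")
```

On the real data, `check_filters(df, EXPECTED)` should return an empty list.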

---
## 4. Visualize — Individual Plots

Each function returns a `matplotlib.axes.Axes` object, so you can customize the plot further.

In [None]:
plot_cell_type_counts(df, top_n=20)
plt.tight_layout()
plt.show()

In [None]:
plot_tissue_counts(df, column="tissue_general", top_n=20)
plt.tight_layout()
plt.show()

In [None]:
plot_sex_distribution(df)
plt.tight_layout()
plt.show()

In [None]:
plot_assay_counts(df)
plt.tight_layout()
plt.show()

In [None]:
plot_disease_counts(df)
plt.tight_layout()
plt.show()

In [None]:
plot_dataset_contribution(df, top_n=15)
plt.tight_layout()
plt.show()

In [None]:
plot_development_stage_counts(df, top_n=10)
plt.tight_layout()
plt.show()

In [None]:
plot_donors_per_dataset(df, top_n=15)
plt.tight_layout()
plt.show()
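The returned-`Axes` convention is easy to follow in your own helpers too. Below is a minimal, hypothetical `plot_value_counts` written in the same style (draw on `ax` if one is given, otherwise create one, and return it), followed by restyling of the returned Axes after the fact. The toy assay column is synthetic.

```python
import matplotlib.pyplot as plt
import pandas as pd


def plot_value_counts(df, column, top_n=10, ax=None):
    """Hypothetical plot helper following the same (ax=None) -> Axes convention."""
    if ax is None:
        _, ax = plt.subplots(figsize=(6, 4))
    counts = df[column].value_counts().head(top_n)
    ax.barh(counts.index.astype(str), counts.to_numpy())
    ax.invert_yaxis()  # put the largest bar on top
    ax.set_xlabel("cells")
    ax.set_title(column)
    return ax


df_toy = pd.DataFrame({"assay": ["10x 3' v3"] * 5 + ["10x 5' v1"] * 2 + ["Smart-seq2"]})
ax = plot_value_counts(df_toy, "assay")

# Because the Axes comes back, it can be restyled without touching the helper.
ax.set_xscale("log")
ax.set_title("Assays (log scale)")
```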

### Cell-type-by-tissue heatmap

Since our query is already filtered to mast cells in lung, the cell-type
axis of this heatmap collapses to a single row, so here it mostly shows
how mast cells are spread across the fine-grained `tissue` values.
For a broader query it shows cell types across tissues.

In [None]:
plot_cell_type_by_tissue(df, top_cell_types=12, top_tissues=8)
plt.tight_layout()
plt.show()
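For a rough idea of the numbers behind such a heatmap, the sketch below cross-tabulates the most frequent levels of two columns with `pd.crosstab`. This is one plausible reading of what `plot_cell_type_by_tissue` computes from `top_cell_types` and `top_tissues`; the helper `top_crosstab` and the toy data are illustrative, and the library's implementation may differ. Zero-count levels are dropped before ranking, which matters for categorical columns.

```python
import pandas as pd


def top_crosstab(df, row_col, col_col, top_rows=12, top_cols=8):
    """Cell counts for row_col x col_col, limited to the most frequent levels of each."""
    row_counts = df[row_col].value_counts()
    col_counts = df[col_col].value_counts()
    # Drop zero-count levels (unused categories) before taking the top N.
    top_r = row_counts[row_counts > 0].head(top_rows).index
    top_c = col_counts[col_counts > 0].head(top_cols).index
    subset = df[df[row_col].isin(top_r) & df[col_col].isin(top_c)]
    table = pd.crosstab(subset[row_col], subset[col_col])
    # Order rows and columns by overall frequency rather than crosstab's default order.
    return table.reindex(index=top_r, columns=top_c, fill_value=0)


toy_ct = pd.DataFrame({
    "cell_type": ["mast cell"] * 4 + ["macrophage"] * 2,
    "tissue": ["lung", "lung", "bronchus", "lung", "lung", "bronchus"],
})
print(top_crosstab(toy_ct, "cell_type", "tissue"))
```

The resulting frame can be drawn directly, for example with `ax.imshow(table.to_numpy())`, if you want to build the heatmap yourself.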

---
## 5. Visualize — Summary Dashboard

`plot_metadata_summary` produces a 3 x 2 grid combining six panels in
one figure. Pass `save_path` to write it to disk.

In [None]:
fig = plot_metadata_summary(df, top_n=15, save_path="mast_cell_lung_summary.png")
plt.show()

---
## 6. Custom Multi-Panel Figure

Because every plot function accepts an `ax` parameter, you can compose
any layout you like.

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(20, 5))

plot_dataset_contribution(df, top_n=10, ax=axes[0])
plot_sex_distribution(df, ax=axes[1])
plot_assay_counts(df, ax=axes[2])

fig.suptitle(f"Lung mast cells  ({len(df):,} cells)", fontsize=14)
fig.tight_layout()
plt.show()

---
## 7. Tabular Summaries

Quick cross-tabulations that complement the plots.

In [None]:
# Cells and donors per dataset (observed=True skips empty categorical levels)
df.groupby("dataset_id", observed=True)["donor_id"].agg(
    cells="size",
    donors="nunique",
).sort_values("cells", ascending=False)

In [None]:
# Cells per assay and sex
pd.crosstab(df["assay"], df["sex"])
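If `df` holds categorical columns (much of the Census `obs` data is dictionary-encoded; whether `run_query` keeps it categorical is not shown here), grouping with the pandas default `observed=False` emits a row for every category, including datasets that contributed no cells to this slice. The toy example below shows the difference; passing `observed=True` keeps only groups that actually occur.

```python
import pandas as pd

# Categorical column whose categories include a dataset ("ds-c") with no rows in this slice.
toy_cat = pd.DataFrame({
    "dataset_id": pd.Categorical(["ds-a", "ds-a", "ds-b"], categories=["ds-a", "ds-b", "ds-c"]),
    "donor_id": ["d1", "d2", "d3"],
})

all_levels = toy_cat.groupby("dataset_id", observed=False)["donor_id"].agg(cells="size", donors="nunique")
seen_only = toy_cat.groupby("dataset_id", observed=True)["donor_id"].agg(cells="size", donors="nunique")
print(all_levels)
print(seen_only)
```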

---
## 8. Down-Sampled Query — Single Dataset

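For rapid prototyping you can restrict the query to a single known
dataset using `extra_value_filter`. Because only one dataset's cells
match, the query returns much faster than a broad Census query and gives
you a fixed-size slice. The slice is reproducible across runs only if you
pin `census_version` to a dated release: `"stable"` moves to each new
long-term-support release.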

We use dataset **`1b350d0a-4535-4879-beb6-1142f3f94947`** as our
down-sampled reference.

In [None]:
DATASET_ID = "1b350d0a-4535-4879-beb6-1142f3f94947"

spec_ds = QuerySpec(
    target=CensusTarget(census_version="stable", organism="homo_sapiens"),
    obs_filters=ObsFilters(
        is_primary_data=True,
        suspension_type="cell",
        extra_value_filter=f"dataset_id == '{DATASET_ID}'",
    ),
    output=OutputSpec(mode="pandas"),
    tiledb_config={
        "vfs.s3.connect_timeout_ms": 60000,
        "vfs.s3.request_timeout_ms": 600000,
    },
)

print("Down-sampled query spec created")
print(f"  dataset_id: {DATASET_ID}")
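Because `extra_value_filter` is raw filter text built with an f-string, a typo in the ID simply matches nothing. A small guard can validate the ID first. The hypothetical `dataset_filter` helper below parses each ID with `uuid.UUID` (which raises `ValueError` on malformed input), normalizes it to lowercase hyphenated form, and emits either an `==` or an `in [...]` clause. It assumes Census dataset IDs are standard UUID strings, as the one used here is.

```python
import uuid


def dataset_filter(*dataset_ids: str) -> str:
    """Build a dataset_id value filter, rejecting anything that is not a UUID."""
    if not dataset_ids:
        raise ValueError("need at least one dataset_id")
    # uuid.UUID raises ValueError on typos or stray characters; str() normalizes case.
    ids = [str(uuid.UUID(d)) for d in dataset_ids]
    if len(ids) == 1:
        return f"dataset_id == '{ids[0]}'"
    return f"dataset_id in {ids!r}"


print(dataset_filter("1b350d0a-4535-4879-beb6-1142f3f94947"))
```

The returned string could then be passed as `extra_value_filter=dataset_filter(DATASET_ID)`.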

In [None]:
%%time
df_ds = run_query(spec_ds)

print(f"\nReturned {len(df_ds):,} cells from dataset {DATASET_ID}")
print(f"Unique cell types: {df_ds['cell_type'].nunique()}")
print(f"Unique donors    : {df_ds['donor_id'].nunique()}")

In [None]:
df_ds.head()

In [None]:
# Cell type breakdown for the down-sampled dataset
df_ds["cell_type"].value_counts().head(15)

In [None]:
fig = plot_metadata_summary(df_ds, top_n=10, save_path="downsampled_dataset_summary.png")
plt.show()

---
## 9. Dataset List Mode

Identify which CELLxGENE datasets contain lung mast cells without
downloading all the metadata.

In [None]:
spec_dslist = QuerySpec(
    target=CensusTarget(census_version="stable", organism="homo_sapiens"),
    obs_filters=ObsFilters(
        is_primary_data=True,
        suspension_type="cell",  # match the Section 1 filters
        disease_ontology_term_ids=["PATO:0000461"],
        cell_type_labels=["mast cell"],
        tissue_general_labels=["lung"],
    ),
    output=OutputSpec(mode="dataset_list"),
    tiledb_config={
        "vfs.s3.connect_timeout_ms": 60000,
        "vfs.s3.request_timeout_ms": 600000,
    },
)

dataset_ids = run_query(spec_dslist)

print(f"\n{len(dataset_ids)} datasets contain healthy lung mast cells:")
for ds_id in dataset_ids:
    print(f"  {ds_id}")
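The dataset list and the DataFrame from Section 2 should agree when both specs use the same filters. The hypothetical helper `compare_ids` reports IDs found in only one of them; any difference in filters between the two specs (the `suspension_type` setting, for example) can legitimately produce mismatches.

```python
def compare_ids(from_list, from_df):
    """Return the IDs found in only one of two collections, sorted for stable output."""
    a, b = set(from_list), set(from_df)
    return {
        "only_in_dataset_list": sorted(a - b),
        "only_in_dataframe": sorted(b - a),
    }


report = compare_ids(["ds-a", "ds-b", "ds-c"], ["ds-b", "ds-a", "ds-a"])
print(report)
```

In the notebook this would be `compare_ids(dataset_ids, df["dataset_id"].unique())`, assuming `dataset_ids` is an iterable of ID strings.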

---
## 10. Arrow / Parquet Streaming

For large queries — or when you want to write results directly to disk
without loading everything into memory — use the `arrow` or `parquet`
output modes. These stream Arrow RecordBatches from the Census and
either concatenate them into an in-memory Arrow Table (`arrow` mode)
or write them incrementally to a Parquet file (`parquet` mode).

> **Note:** We use the down-sampled single dataset here to keep run
> times short. In production you would use these modes for large,
> broad queries (e.g. all cell types across all tissues).

### 10a. Arrow mode — in-memory Arrow Table

In [None]:
%%time
import pyarrow as pa

# Re-use the down-sampled dataset filter for a fast Arrow demo
spec_arrow = QuerySpec(
    target=CensusTarget(census_version="stable", organism="homo_sapiens"),
    obs_filters=ObsFilters(
        is_primary_data=True,
        suspension_type="cell",
        extra_value_filter=f"dataset_id == '{DATASET_ID}'",
    ),
    output=OutputSpec(mode="arrow"),
    tiledb_config={
        "vfs.s3.connect_timeout_ms": 60000,
        "vfs.s3.request_timeout_ms": 600000,
    },
)

arrow_table = run_query(spec_arrow)

print(f"Arrow Table: {arrow_table.num_rows:,} rows x {arrow_table.num_columns} columns")
print(f"Schema:\n{arrow_table.schema}")

# Convert to pandas when ready
df_from_arrow = arrow_table.to_pandas()
df_from_arrow.head()

### 10b. Parquet mode — stream to disk

Parquet mode writes results directly to a Parquet file without loading
all data into memory. This is ideal for very large queries.

In [None]:
%%time
# Re-use the down-sampled dataset filter for a fast Parquet demo
spec_parquet = QuerySpec(
    target=CensusTarget(census_version="stable", organism="homo_sapiens"),
    obs_filters=ObsFilters(
        is_primary_data=True,
        suspension_type="cell",
        extra_value_filter=f"dataset_id == '{DATASET_ID}'",
    ),
    output=OutputSpec(
        mode="parquet",
        outpath="downsampled_dataset.parquet",
        parquet_compression="zstd",
    ),
    tiledb_config={
        "vfs.s3.connect_timeout_ms": 60000,
        "vfs.s3.request_timeout_ms": 600000,
    },
)

outpath = run_query(spec_parquet)
print(f"Parquet file written to: {outpath}")

# Read it back to verify
import pyarrow.parquet as pq
table_back = pq.read_table(str(outpath))  # a pyarrow.Table
print(f"Read back: {table_back.num_rows:,} rows x {table_back.num_columns} columns")

---
## 11. Next Steps

From here you can:

- Switch to `mode="anndata"` to download expression data and feed it
  into a scanpy pipeline (QC, normalization, clustering, UMAP).
- Use the `dataset_list` output to select specific datasets for
  deeper analysis.
- Use `mode="parquet"` to stream large results directly to disk.
- Use `mode="arrow"` for in-memory columnar analysis with PyArrow.
- Filter to a single dataset with `extra_value_filter` for quick
  prototyping (as shown in Section 8).
- Save the DataFrame to CSV/Parquet for use in other tools:

```python
df.to_csv("metadata.csv", index=False)
df.to_parquet("metadata.parquet", index=False)
```