diff --git a/CHANGELOG.md b/CHANGELOG.md index 4391fcc..9c7f96b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,22 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 (e.g. `"genotype:HA1:158"` or `"genotype:HA1:158,189"`). Colors, category ordering, and the bottom-of-plot legend match the Nextstrain view of the same tree. +- `tree_color_scale` (default `None`): override the default coloring + with an explicit `{category: color}` mapping. Keys must match the + tree's categories one-to-one and the legend order follows the + user's key order. CLI form: `"value1=#hex1,value2=#hex2,..."`. +- `tree_color_legend_format` (default `None`): pass any subset of + Vega-Lite's + [Legend properties](https://vega.github.io/vega-lite/docs/legend.html#properties) + as a dict to style the tree's color legend (`orient`, `direction`, + `columns`, `padding`, `labelFontSize`, `titleFontSize`, …). When + `orient` is `"left"` or `"right"` and the user has not set + `columns` or `direction`, `columns=1` is forced so entries stack + vertically. CLI form: a JSON object string. +- `tree_color_legend_show` (default `True`): set to `False` to hide + the tree's color legend entirely while still coloring the tree. +- `scale_bar_font_size` (default `10`): font size for the tree's + scale bar label. ### Changed diff --git a/docs/examples.md b/docs/examples.md index c159454..f11f009 100644 --- a/docs/examples.md +++ b/docs/examples.md @@ -170,6 +170,98 @@ so renders three states (N, K, D): CLI flag: `--color-tree-by genotype:HA1:158`. In Python: `color_tree_by="genotype:HA1:158"`. +### Customizing the colors and legend + +Four optional knobs let you override the defaults: + +- `tree_color_scale` — supply your own `{category: color}` mapping + instead of the Auspice palette. Keys must match the tree's + categories one-to-one (extras or misses are an error). The legend + order follows the order of keys you pass. 
`"unknown"` is always + gray and must not be specified. +- `tree_color_legend_format` — pass any subset of Vega-Lite's + [Legend properties](https://vega.github.io/vega-lite/docs/legend.html#properties) + as a dict to style the legend (`orient`, `direction`, `columns`, + `padding`, `offset`, `labelFontSize`, `titleFontSize`, …). Smart + default: when `orient` is `"left"` or `"right"` and you have not + set `columns` or `direction`, entries stack vertically + (`columns=1` is forced). +- `tree_color_legend_show` — set to `False` to hide the legend + entirely while keeping the tree colored. +- `scale_bar_font_size` — font size (px) for the scale bar's label. + +The exact category strings depend on the form of `color_tree_by`: + +| `color_tree_by` | Example category strings | +|---|---| +| node attribute (e.g. `"subclade"`) | `"K"`, `"J.2"`, `"J.2.4"` | +| `"genotype:HA1:158"` (single site) | `"K158"`, `"R158"`, `"E158"` | +| `"genotype:HA1:158,189"` (both sites vary) | `"K158/E189"`, `"R158/E189"` | +| `"genotype:HA1:158,189"` (189 invariant) | `"K158"`, `"R158"` (invariant site dropped) | + +If you pass keys that don't match, the error message lists the +tree's actual categories so you can copy them. + +The H3N2 example below uses three of the four knobs: an explicit +6-color palette (a colorblind-safe Okabe–Ito-inspired set) ordered +K → J.2.4 → J.2.3 → J.2.2 → J.2 → G.1.3.1, the legend moved to the +**left** of the combined plot at 14 px, and a matching 14 px +scale-bar label (`tree_color_legend_show` keeps its default). 
+ +![H3N2 combined chart with custom colors and legend](images/h3n2_combined_custom_colors.svg) + +[Open the interactive chart in a new tab →](charts/h3n2_combined_custom_colors.html){target="_blank"} + +```python +out = tree_annotated_plot.plot( + "examples/data/flu-seqneut-2025to2026_H3N2.json", + chart, + chart_strain_field="axis_label", + tree_strain_field="derived_haplotype", + branch_length="div", + tree_size=140, + scale_bar=True, + branch_length_units="substitutions", + color_tree_by="subclade", + tree_color_scale={ + "K": "#0072B2", + "J.2.4": "#009E73", + "J.2.3": "#D55E00", + "J.2.2": "#CC79A7", + "J.2": "#56B4E9", + "G.1.3.1": "#E69F00", + }, + tree_color_legend_format={ + "orient": "left", + "labelFontSize": 14, + "titleFontSize": 14, + }, + scale_bar_font_size=14, +) +``` + +`--tree-color-scale` on the CLI is a comma-separated list of +`key=color` pairs and `--tree-color-legend-format` is a JSON object +string (quote the whole argument in both cases so the shell doesn't +interpret `#`, braces, or quotes): + +```bash +tree-annotated-plot \ + --tree examples/data/flu-seqneut-2025to2026_H3N2.json \ + --chart examples/data/flu-seqneut-2025to2026_H3N2_titers.json \ + --chart-strain-field axis_label \ + --tree-strain-field derived_haplotype \ + --branch-length div \ + --tree-size 140 \ + --scale-bar \ + --branch-length-units substitutions \ + --color-tree-by subclade \ + --tree-color-scale "K=#0072B2,J.2.4=#009E73,J.2.3=#D55E00,J.2.2=#CC79A7,J.2=#56B4E9,G.1.3.1=#E69F00" \ + --tree-color-legend-format '{"orient":"left","labelFontSize":14,"titleFontSize":14}' \ + --scale-bar-font-size 14 \ + --output examples/data/h3n2_combined_custom_colors.html +``` + ### Reproduce — command line ```bash diff --git a/scripts/generate_docs_assets.py b/scripts/generate_docs_assets.py index 976ce75..1b9beec 100644 --- a/scripts/generate_docs_assets.py +++ b/scripts/generate_docs_assets.py @@ -183,6 +183,45 @@ def _render_kikawa() -> None: ) _save_pair(out, 
"h3n2_combined_genotype_158") + # H3N2 once more, demonstrating all four appearance-tuning knobs: + # an explicit `tree_color_scale` (Okabe-Ito-inspired colorblind-safe + # palette, ordered K, J.2.4, J.2.3, J.2.2, J.2, G.1.3.1), a 14-pt + # legend on the left, and a 14-pt scale-bar label. + h3n2_chart_custom = builder.make_chart( + subtype="H3N2", + chart_type="iqr", + titers=titers, + viruses=viruses, + metadata=metadata, + all_cohorts=all_cohorts, + ) + out = tree_annotated_plot.plot( + DATA_DIR / "flu-seqneut-2025to2026_H3N2.json", + h3n2_chart_custom, + chart_strain_field="axis_label", + tree_strain_field="derived_haplotype", + branch_length="div", + tree_size=140, + scale_bar=True, + branch_length_units="substitutions", + color_tree_by="subclade", + tree_color_scale={ + "K": "#0072B2", + "J.2.4": "#009E73", + "J.2.3": "#D55E00", + "J.2.2": "#CC79A7", + "J.2": "#56B4E9", + "G.1.3.1": "#E69F00", + }, + tree_color_legend_format={ + "orient": "left", + "labelFontSize": 14, + "titleFontSize": 14, + }, + scale_bar_font_size=14, + ) + _save_pair(out, "h3n2_combined_custom_colors") + def main() -> None: """Render every example to SVG + interactive HTML under `docs/`.""" diff --git a/src/tree_annotated_plot/_color.py b/src/tree_annotated_plot/_color.py index 87a681c..d46641a 100644 --- a/src/tree_annotated_plot/_color.py +++ b/src/tree_annotated_plot/_color.py @@ -119,6 +119,7 @@ def compute_node_color_values( root: TreeNode, color_spec: str, auspice_meta: dict | None = None, + tree_color_scale: dict[str, str] | None = None, ) -> ColorMapping: """Walk the tree and resolve per-node color categories + scale arrays. @@ -135,6 +136,11 @@ def compute_node_color_values( is available (caller passed a pre-built `TreeNode`). Used only to consult ``meta.colorings[].scale`` and ``.title`` for node-attr specs; ignored for genotype specs. + tree_color_scale + Optional user-supplied {category: color} mapping. 
When provided, its + keys must match the tree's real categories one-to-one (mismatch is a + ``ValueError``); the legend order is the dict's insertion order. + ``"unknown"`` is always gray and must not be specified. Returns ------- @@ -154,7 +160,10 @@ def compute_node_color_values( values_by_node = _color_by_genotype(root, gene, sites) categories = _ordered_categories(values_by_node.values()) - domain, range_ = _resolve_scale(categories, parsed, auspice_meta) + if tree_color_scale is not None: + domain, range_ = _apply_user_scale(categories, tree_color_scale) + else: + domain, range_ = _resolve_scale(categories, parsed, auspice_meta) legend_title = _resolve_legend_title(color_spec, parsed, auspice_meta) legend_values = _resolve_legend_values(domain, values_by_node, root) return ColorMapping( @@ -166,6 +175,54 @@ def compute_node_color_values( ) +def _apply_user_scale( + categories: list[str], + user_scale: dict[str, str], +) -> tuple[list[str], list[str]]: + """Build (domain, range_) from a user-supplied {category: color} dict. + + The user's key order is the legend order (Python dicts preserve insertion + order). Keys must match the tree's real categories one-to-one — extras or + misses raise ``ValueError`` listing the actual tree categories so the + user can copy-paste the correct keys (especially relevant for + genotype/haplotype categories like ``"K158"`` or ``"K158/E189"`` whose + exact form depends on the data). ``"unknown"`` is always gray and is + appended automatically when present in ``categories`` — the user must + not include it. + """ + real_categories = [c for c in categories if c != _UNKNOWN] + user_keys = list(user_scale.keys()) + + if _UNKNOWN in user_keys: + raise ValueError( + f"tree_color_scale must not contain key {_UNKNOWN!r}; " + "missing values are always rendered gray." 
+ ) + + tree_set = set(real_categories) + user_set = set(user_keys) + missing = sorted(tree_set - user_set) + extra = sorted(user_set - tree_set) + if missing or extra: + msg = ( + "tree_color_scale keys don't match the tree's categories.\n" + f" Tree categories: {real_categories!r}\n" + f" Provided keys: {user_keys!r}" + ) + if missing: + msg += f"\n Missing from your scale: {missing!r}" + if extra: + msg += f"\n Unexpected in your scale: {extra!r}" + raise ValueError(msg) + + domain = list(user_keys) + range_ = [user_scale[k] for k in user_keys] + if _UNKNOWN in categories: + domain.append(_UNKNOWN) + range_.append(_GRAY) + return domain, range_ + + def _resolve_legend_values( domain: list[str], values_by_node: dict[str, str], diff --git a/src/tree_annotated_plot/_config.py b/src/tree_annotated_plot/_config.py index e827f41..75c1c98 100644 --- a/src/tree_annotated_plot/_config.py +++ b/src/tree_annotated_plot/_config.py @@ -18,7 +18,7 @@ import dataclasses import textwrap import typing -from typing import Annotated, Literal +from typing import Annotated, Any, Literal TreeLocation = Literal["left", "right", "top", "bottom"] @@ -165,6 +165,49 @@ class PlotConfig: "leaves the tree black.", ] = None + tree_color_scale: Annotated[ + dict[str, str] | None, + "Hardcoded color scale that overrides the default coloring. Keys " + "are category labels, values are colors (any Vega-Lite-compatible " + "string — e.g. hex codes). The legend order follows the order keys " + "appear here. The keys must match the tree's categories one-to-one " + '(extra or missing keys are an error). "unknown" is always gray ' + "and must not be specified. For genotype/haplotype colorings, the " + 'category strings include the site number (e.g. "K158" or ' + '"K158/E189"); a mismatch error lists the actual tree categories. 
' + 'CLI form: "value1=#hex1,value2=#hex2,...".', + ] = None + + tree_color_legend_format: Annotated[ + dict[str, Any] | None, + "Vega-Lite Legend properties to apply to the tree-coloring legend. " + "Pass any subset of the keys at " + "https://vega.github.io/vega-lite/docs/legend.html#properties as a " + "Python dict (e.g. " + '`{"orient": "left", "labelFontSize": 13, "titleFontSize": 13}`). ' + 'Common keys: "orient" (default "bottom"), "direction", "columns", ' + '"padding", "offset", "labelFontSize", "titleFontSize". Smart ' + 'default: when "orient" is "left" or "right" and you have not set ' + '"columns" or "direction", "columns" is forced to 1 so entries ' + "stack vertically. " + "None (default) leaves Vega-Lite's defaults. Has no effect when " + "`color_tree_by` is None. CLI form: a JSON object string (quote the " + 'whole argument), e.g. \'{"orient":"left","labelFontSize":13}\'.', + ] = None + + tree_color_legend_show: Annotated[ + bool, + "Whether to render the tree-coloring legend. On (default) shows it. " + "Off hides the legend entirely while keeping the tree colored. Has " + "no effect when `color_tree_by` is None.", + ] = True + + scale_bar_font_size: Annotated[ + float, + "Font size (px) for the tree's scale bar label. Default 10. Has no " + "effect when `scale_bar` is off.", + ] = 10.0 + # Sidecar for Python-docstring-only prose, keyed by PlotConfig field name. # Empty by default — add an entry when a field's docstring entry needs more diff --git a/src/tree_annotated_plot/_plot.py b/src/tree_annotated_plot/_plot.py index 25b4a74..c39d894 100644 --- a/src/tree_annotated_plot/_plot.py +++ b/src/tree_annotated_plot/_plot.py @@ -17,6 +17,19 @@ from . import _color, _config, _tree from ._config import PlotConfig, TreeLocation +# Vega-Lite `orient` values that anchor the legend to the chart's left or +# right edge — these are the ones whose natural layout is one entry per row. 
+# (The corner orients `top-left`/`top-right`/`bottom-left`/`bottom-right` +# are deliberately excluded from this smart default. NOTE(review): Vega-Lite +# documents symbol legends there as vertical by default; please confirm.) +# If a chart's config-level `legend.columns` default is set (a common +# pattern in altair theme/config setups), Vega-Lite will pack entries into +# multiple columns even with a left/right orient. We force columns=1 in +# that case so the user's choice of `orient: "left"` or `"right"` produces +# vertical stacking without requiring them to know about the `columns` +# interaction. +_VERTICAL_ORIENTS = frozenset({"left", "right"}) + # Accepted chart input forms for the public `plot` function. ChartInput = alt.TopLevelMixin | str | Path | dict @@ -47,6 +60,10 @@ def plot( strain_label_font_weight: Literal["normal", "bold"] = "normal", shift_tree_loc: int = 0, color_tree_by: str | None = None, + tree_color_scale: dict[str, str] | None = None, + tree_color_legend_format: dict[str, Any] | None = None, + tree_color_legend_show: bool = True, + scale_bar_font_size: float = 10.0, ) -> alt.HConcatChart | alt.VConcatChart: """Return an Altair chart with a phylogenetic tree drawn alongside `chart`.""" return _build( @@ -70,6 +87,10 @@ def plot( strain_label_font_weight=strain_label_font_weight, shift_tree_loc=shift_tree_loc, color_tree_by=color_tree_by, + tree_color_scale=tree_color_scale, + tree_color_legend_format=tree_color_legend_format, + tree_color_legend_show=tree_color_legend_show, + scale_bar_font_size=scale_bar_font_size, ), ) @@ -166,9 +187,17 @@ def _build( ) if config.color_tree_by is not None: color_mapping = _color.compute_node_color_values( - root, config.color_tree_by, auspice_meta=auspice_meta + root, + config.color_tree_by, + auspice_meta=auspice_meta, + tree_color_scale=config.tree_color_scale, ) else: + if config.tree_color_scale is not None: + raise ValueError( + "tree_color_scale was supplied but color_tree_by is None; " + "the 
override only applies when the tree is being colored." + ) color_mapping = None tree_chart = _build_tree_chart( root, @@ -181,6 +210,7 @@ def _build( tree_node_size=config.tree_node_size, leader_line_width=config.leader_line_width, scale_bar=config.scale_bar, + scale_bar_font_size=config.scale_bar_font_size, branch_length=config.branch_length, branch_length_units=config.branch_length_units, connect_leader_to_label=config.connect_leader_to_label, @@ -189,6 +219,8 @@ def _build( shift_tree_loc=config.shift_tree_loc, tip_names=tip_names, color_mapping=color_mapping, + legend_format=config.tree_color_legend_format, + legend_show=config.tree_color_legend_show, ) new_chart = _apply_tree_order_to_chart_object( @@ -1141,6 +1173,7 @@ def _build_scale_bar_layer( extra_units: float, strain_axis: str, label: str, + font_size: float = 10.0, ) -> alt.LayerChart: """Build a 2-layer (bar rule + text) chart for the scale bar. @@ -1182,7 +1215,7 @@ def _build_scale_bar_layer( ) text = ( alt.Chart(text_df) - .mark_text(fontSize=10, align="center", baseline="top") + .mark_text(fontSize=font_size, align="center", baseline="top") .encode(x="x:Q", y="y:Q", text="label:N") ) else: @@ -1196,7 +1229,7 @@ def _build_scale_bar_layer( ) text = ( alt.Chart(text_df) - .mark_text(fontSize=10, align="center", baseline="middle", angle=270) + .mark_text(fontSize=font_size, align="center", baseline="middle", angle=270) .encode(y="x:Q", x="y:Q", text="label:N") ) return bar + text @@ -1220,6 +1253,7 @@ def _build_tree_chart( tree_node_size: float = 45, leader_line_width: float = 1.0, scale_bar: bool = False, + scale_bar_font_size: float = 10.0, branch_length: str = "div", branch_length_units: str | None = None, connect_leader_to_label: bool = False, @@ -1228,6 +1262,8 @@ def _build_tree_chart( shift_tree_loc: int = 0, tip_names: list[str] | None = None, color_mapping: _color.ColorMapping | None = None, + legend_format: dict[str, Any] | None = None, + legend_show: bool = True, ) -> alt.Chart: """Build 
the tree panel. @@ -1276,22 +1312,44 @@ def _build_tree_chart( tips_df = tips_df.assign( color_value=tips_df["name"].map(color_mapping.values_by_node) ) + # Defaults the user can override by passing the same key in + # `legend_format`. Title comes from the color mapping's derived + # title (e.g. "subclade" or "HA1 site 158"); orient defaults to + # bottom to match the docs. legend_kwargs: dict = { "title": color_mapping.legend_title, "orient": "bottom", } + if legend_format is not None: + legend_kwargs.update(legend_format) if color_mapping.legend_values is not None: # Restrict the legend display without touching the scale, so # internal-node segments still render gray when "unknown" is on # the tree but no tip is. legend_kwargs["values"] = list(color_mapping.legend_values) + # Smart default for vertical stacking: when the (final) orient + # places the legend on the chart's left or right edge and the user + # has not explicitly set `columns` or `direction`, force columns=1. + # This counteracts a chart-level config default of `legend.columns` + # > 1 (some altair theme presets set this), which would otherwise + # pack a side-anchored legend into multiple columns. 
+ if ( + legend_kwargs.get("orient") in _VERTICAL_ORIENTS + and "columns" not in legend_kwargs + and "direction" not in legend_kwargs + ): + legend_kwargs["columns"] = 1 + if not legend_show: + legend_arg: alt.Legend | None = None + else: + legend_arg = alt.Legend(**legend_kwargs) color_enc = alt.Color( "color_value:N", scale=alt.Scale( domain=list(color_mapping.domain), range=list(color_mapping.range_), ), - legend=alt.Legend(**legend_kwargs), + legend=legend_arg, ) else: color_enc = None @@ -1372,6 +1430,7 @@ def _build_tree_chart( extra_units=extra_units, strain_axis=strain_axis, label=bar_label, + font_size=scale_bar_font_size, ) else: extra_units = 0.0 diff --git a/src/tree_annotated_plot/cli.py b/src/tree_annotated_plot/cli.py index 866904e..b45c99c 100644 --- a/src/tree_annotated_plot/cli.py +++ b/src/tree_annotated_plot/cli.py @@ -13,6 +13,7 @@ from __future__ import annotations import dataclasses +import json import types import typing from pathlib import Path @@ -31,6 +32,100 @@ _SCALAR_CLICK_TYPE = {int: click.INT, float: click.FLOAT, str: click.STRING} +class _ColorScaleParamType(click.ParamType): + """Parse a `"key1=color1,key2=color2,..."` string into an ordered dict. + + Hex colors work (e.g. ``"K=#416DCE,J.2=#59A3AA"``); quote the argument so + the shell doesn't treat `#` as a comment. Pairs are split on every comma, so + comma-containing colors (``rgb(1,2,3)``) are rejected. Blank input yields ``None``. 
+ """ + + name = "color_scale" + + def convert(self, value, param, ctx): # type: ignore[override] + if value is None or isinstance(value, dict): + return value + text = str(value).strip() + if not text: + return None + result: dict[str, str] = {} + for piece in text.split(","): + piece = piece.strip() + if not piece: + continue + if "=" not in piece: + self.fail( + f"expected 'key=color' pairs separated by commas; " + f"got {piece!r}.", + param, + ctx, + ) + key, color = piece.split("=", 1) + key = key.strip() + color = color.strip() + if not key or not color: + self.fail(f"empty key or color in {piece!r}.", param, ctx) + if key in result: + self.fail(f"duplicate key {key!r} in tree_color_scale.", param, ctx) + result[key] = color + return result or None + + +_COLOR_SCALE_PARAM_TYPE = _ColorScaleParamType() + + +class _JsonDictParamType(click.ParamType): + """Parse a JSON-object string into a dict. + + Used for ``--tree-color-legend-format``: any subset of Vega-Lite's + Legend properties as a JSON object (e.g. + ``'{"orient":"left","labelFontSize":13}'``). The user must quote the + whole argument so the shell doesn't interpret braces or quotes. + """ + + name = "json_dict" + + def convert(self, value, param, ctx): # type: ignore[override] + if value is None or isinstance(value, dict): + return value + text = str(value).strip() + if not text: + return None + try: + parsed = json.loads(text) + except json.JSONDecodeError as exc: + self.fail(f"invalid JSON: {exc}", param, ctx) + if not isinstance(parsed, dict): + self.fail( + f'expected a JSON object (e.g. 
\'{{"orient":"left"}}\'); ' + f"got {type(parsed).__name__}.", + param, + ctx, + ) + return parsed + + +_JSON_DICT_PARAM_TYPE = _JsonDictParamType() + + +def _is_str_dict(tp: Any) -> bool: + """Return True for `dict[str, str]` only (not `dict[str, Any]`).""" + if get_origin(tp) is dict: + args = get_args(tp) + return args == (str, str) + return False + + +def _is_dict_any(tp: Any) -> bool: + """Return True for `dict[str, Any]` or a bare `dict`.""" + if tp is dict: + return True + if get_origin(tp) is dict: + args = get_args(tp) + return args == (str, Any) or args == () + return False + + def _option_for_field(field: dataclasses.Field, hints: dict) -> Any: """Build a `click.option(...)` decorator from a PlotConfig field. @@ -73,12 +168,48 @@ def _option_for_field(field: dataclasses.Field, hints: dict) -> Any: kwargs["required"] = True return click.option(cli_name, **kwargs) + # dict[str, str] (with or without `| None`) → custom comma-separated parser. + if _is_str_dict(real_type): + return click.option( + cli_name, + type=_COLOR_SCALE_PARAM_TYPE, + default=field.default if has_default else None, + help=description, + show_default=False, + ) + + # dict[str, Any] (with or without `| None`) → JSON-object parser. + if _is_dict_any(real_type): + return click.option( + cli_name, + type=_JSON_DICT_PARAM_TYPE, + default=field.default if has_default else None, + help=description, + show_default=False, + ) + # Optional / Union — unwrap to the non-None branch. 
origin = get_origin(real_type) if origin in (typing.Union, types.UnionType): non_none = [a for a in get_args(real_type) if a is not type(None)] if len(non_none) == 1: inner = non_none[0] + if _is_str_dict(inner): + return click.option( + cli_name, + type=_COLOR_SCALE_PARAM_TYPE, + default=field.default if has_default else None, + help=description, + show_default=False, + ) + if _is_dict_any(inner): + return click.option( + cli_name, + type=_JSON_DICT_PARAM_TYPE, + default=field.default if has_default else None, + help=description, + show_default=False, + ) if get_origin(inner) is Literal: choices = [str(c) for c in get_args(inner)] return click.option( diff --git a/tests/test_appearance_tuning.py b/tests/test_appearance_tuning.py new file mode 100644 index 0000000..7ee5444 --- /dev/null +++ b/tests/test_appearance_tuning.py @@ -0,0 +1,797 @@ +"""Tests for `tree_color_scale`, `tree_color_legend_format`, +`tree_color_legend_show`, and `scale_bar_font_size`.""" + +from __future__ import annotations + +import json +from typing import Any + +import altair as alt +import pandas as pd +import pytest + +import tree_annotated_plot +from tree_annotated_plot import _color, _tree + +# ----------------------------------------------------------------------------- +# Test fixtures (inlined rather than imported across test modules so this file +# stays runnable on its own — `tests/` has no __init__.py). 
+# ----------------------------------------------------------------------------- + + +def _attr_auspice() -> dict: + """Tiny tree with `subclade` on every node and tips A..D in two clades.""" + return { + "version": "v2", + "meta": {}, + "tree": { + "name": "ROOT", + "node_attrs": {"div": 0.0, "subclade": {"value": "X"}}, + "children": [ + { + "name": "INT_LEFT", + "node_attrs": {"div": 0.02, "subclade": {"value": "X"}}, + "children": [ + { + "name": "A", + "node_attrs": { + "div": 0.04, + "subclade": {"value": "X"}, + }, + }, + { + "name": "B", + "node_attrs": { + "div": 0.05, + "subclade": {"value": "X"}, + }, + }, + ], + }, + { + "name": "INT_RIGHT", + "node_attrs": {"div": 0.03, "subclade": {"value": "Y"}}, + "children": [ + { + "name": "C", + "node_attrs": { + "div": 0.06, + "subclade": {"value": "Y"}, + }, + }, + { + "name": "D", + "node_attrs": { + "div": 0.07, + "subclade": {"value": "Z"}, + }, + }, + ], + }, + ], + }, + } + + +def _genotype_auspice() -> dict: + return { + "version": "v2", + "meta": {}, + "tree": { + "name": "ROOT", + "node_attrs": {"div": 0.0}, + "children": [ + { + "name": "tip_A", + "node_attrs": {"div": 0.04}, + "branch_attrs": {"mutations": {"HA1": ["N158K"]}}, + }, + {"name": "tip_B", "node_attrs": {"div": 0.05}}, + { + "name": "INT1", + "node_attrs": {"div": 0.02}, + "branch_attrs": {"mutations": {"HA1": ["N158D"]}}, + "children": [ + {"name": "tip_C", "node_attrs": {"div": 0.06}}, + {"name": "tip_D", "node_attrs": {"div": 0.07}}, + ], + }, + ], + }, + } + + +def _haplotype_auspice() -> dict: + return { + "version": "v2", + "meta": {}, + "tree": { + "name": "ROOT", + "node_attrs": {"div": 0.0}, + "children": [ + { + "name": "tip_A", + "node_attrs": {"div": 0.04}, + "branch_attrs": {"mutations": {"HA1": ["N158K"]}}, + }, + { + "name": "tip_B", + "node_attrs": {"div": 0.05}, + "branch_attrs": {"mutations": {"HA1": ["S189T"]}}, + }, + { + "name": "INT1", + "node_attrs": {"div": 0.02}, + "branch_attrs": {"mutations": {"HA1": ["N158K", 
"S189T"]}}, + "children": [ + {"name": "tip_C", "node_attrs": {"div": 0.06}}, + {"name": "tip_D", "node_attrs": {"div": 0.07}}, + ], + }, + ], + }, + } + + +def _load(d: dict) -> _tree.TreeNode: + return _tree.load_auspice(d, tree_strain_field="name", branch_length="div") + + +def _vertical_chart(strains: list[str]) -> alt.Chart: + df = pd.DataFrame({"strain": strains, "titer": [1.0, 2.0, 4.0, 8.0]}) + return ( + alt.Chart(df) + .mark_circle() + .encode(x="titer:Q", y=alt.Y("strain:N")) + .properties(width=200, height=200) + ) + + +def _kw(): + return dict( + chart_strain_field="strain", tree_strain_field="name", branch_length="div" + ) + + +def _find_color_encodings(node: Any) -> list[tuple[str, dict]]: + hits: list[tuple[str, dict]] = [] + + def walk(o: Any, path: str) -> None: + if isinstance(o, dict): + enc = o.get("encoding") + if isinstance(enc, dict) and "color" in enc: + hits.append((path, enc["color"])) + for k, v in o.items(): + walk(v, f"{path}.{k}") + elif isinstance(o, list): + for i, v in enumerate(o): + walk(v, f"{path}[{i}]") + + walk(node, "") + return hits + + +def _tree_panel_color_encodings(out) -> list[dict]: + spec = out.to_dict() + panels = spec.get("hconcat") or spec.get("vconcat") or [] + assert panels + tree_panel = panels[0] + return [ + enc + for _, enc in _find_color_encodings(tree_panel) + if isinstance(enc, dict) and enc.get("field") == "color_value" + ] + + +def _run_cli(args, expect_success: bool = True): + from click.testing import CliRunner + + from tree_annotated_plot import cli as cli_module + + runner = CliRunner() + result = runner.invoke(cli_module.main, args, catch_exceptions=False) + if expect_success and result.exit_code != 0: + raise AssertionError(f"CLI exit {result.exit_code}\n{result.output}") + return result + + +def _cli_setup(tmp_path, tree_dict: dict, chart: alt.Chart): + tree_path = tmp_path / "tree.json" + chart_path = tmp_path / "chart.json" + out_path = tmp_path / "out.json" + 
tree_path.write_text(json.dumps(tree_dict)) + chart.save(str(chart_path)) + return tree_path, chart_path, out_path + + +# ----------------------------------------------------------------------------- +# tree_color_scale: validation +# ----------------------------------------------------------------------------- + + +def test_tree_color_scale_overrides_palette_in_user_order(): + root = _load(_attr_auspice()) + user_scale = {"Z": "#111111", "Y": "#222222", "X": "#333333"} + m = _color.compute_node_color_values(root, "subclade", tree_color_scale=user_scale) + # Domain order is the dict's insertion order — not descending frequency. + assert m.domain == ["Z", "Y", "X"] + assert m.range_ == ["#111111", "#222222", "#333333"] + + +def test_tree_color_scale_appends_unknown_when_present(): + d = _attr_auspice() + # Strip subclade off one tip so "unknown" is a real category. + d["tree"]["children"][0]["children"][0]["node_attrs"].pop("subclade") + root = _load(d) + m = _color.compute_node_color_values( + root, + "subclade", + tree_color_scale={"X": "#aaaaaa", "Y": "#bbbbbb", "Z": "#cccccc"}, + ) + assert m.domain[-1] == "unknown" + assert m.range_[-1] == "#888888" + # User keys come before "unknown" in the domain. 
+ assert m.domain[:-1] == ["X", "Y", "Z"] + + +def test_tree_color_scale_missing_key_raises_lists_categories(): + root = _load(_attr_auspice()) + with pytest.raises(ValueError) as exc: + _color.compute_node_color_values( + root, + "subclade", + tree_color_scale={"X": "#aaaaaa", "Y": "#bbbbbb"}, # missing Z + ) + msg = str(exc.value) + assert "Tree categories" in msg + assert "'X'" in msg and "'Y'" in msg and "'Z'" in msg + assert "Missing from your scale" in msg + assert "'Z'" in msg + + +def test_tree_color_scale_extra_key_raises(): + root = _load(_attr_auspice()) + with pytest.raises(ValueError) as exc: + _color.compute_node_color_values( + root, + "subclade", + tree_color_scale={ + "X": "#aaaaaa", + "Y": "#bbbbbb", + "Z": "#cccccc", + "W": "#dddddd", # not in tree + }, + ) + msg = str(exc.value) + assert "Unexpected in your scale" in msg + assert "'W'" in msg + + +def test_tree_color_scale_unknown_key_rejected(): + root = _load(_attr_auspice()) + with pytest.raises(ValueError, match="must not contain"): + _color.compute_node_color_values( + root, + "subclade", + tree_color_scale={ + "X": "#aaaaaa", + "Y": "#bbbbbb", + "Z": "#cccccc", + "unknown": "#999999", + }, + ) + + +def test_tree_color_scale_genotype_categories_use_letter_site(): + """User-supplied scale for a single-site genotype must use `{letter}{site}` + keys (e.g. "K158"), not bare letters. Mismatch error names the actual + tree categories so the user can copy them.""" + root = _load(_genotype_auspice()) + with pytest.raises(ValueError) as exc: + _color.compute_node_color_values( + root, + "genotype:HA1:158", + tree_color_scale={"K": "#aaa", "N": "#bbb", "D": "#ccc"}, + ) + msg = str(exc.value) + # The actual tree categories include the site number. + assert "'K158'" in msg or "'N158'" in msg or "'D158'" in msg + + +def test_tree_color_scale_haplotype_categories_use_slash_join(): + root = _load(_haplotype_auspice()) + # Real categories are slash-joined like "K158/T189", "N158/T189", etc. 
+ m_default = _color.compute_node_color_values(root, "genotype:HA1:158,189") + real_cats = [c for c in m_default.domain if c != "unknown"] + # All real cats should contain "/" given two varying sites. + assert all("/" in c for c in real_cats), real_cats + user_scale = {c: "#000000" for c in real_cats} + m = _color.compute_node_color_values( + root, "genotype:HA1:158,189", tree_color_scale=user_scale + ) + # Order matches the dict's insertion order. + assert m.domain[: len(real_cats)] == list(user_scale.keys()) + + +# ----------------------------------------------------------------------------- +# tree_color_scale: end-to-end via plot() +# ----------------------------------------------------------------------------- + + +def test_plot_tree_color_scale_propagates_to_spec(): + out = tree_annotated_plot.plot( + _attr_auspice(), + _vertical_chart(["A", "B", "C", "D"]), + **_kw(), + color_tree_by="subclade", + tree_color_scale={"Z": "#111111", "Y": "#222222", "X": "#333333"}, + ) + encs = _tree_panel_color_encodings(out) + assert encs + scale = encs[0]["scale"] + assert scale["domain"] == ["Z", "Y", "X"] + assert scale["range"] == ["#111111", "#222222", "#333333"] + + +def test_plot_tree_color_scale_without_color_tree_by_raises(): + with pytest.raises(ValueError, match="color_tree_by is None"): + tree_annotated_plot.plot( + _attr_auspice(), + _vertical_chart(["A", "B", "C", "D"]), + **_kw(), + tree_color_scale={"X": "#aaa", "Y": "#bbb", "Z": "#ccc"}, + ) + + +# ----------------------------------------------------------------------------- +# tree_color_legend_format +# ----------------------------------------------------------------------------- + + +def test_legend_format_default_is_bottom_orient_no_overrides(): + out = tree_annotated_plot.plot( + _attr_auspice(), + _vertical_chart(["A", "B", "C", "D"]), + **_kw(), + color_tree_by="subclade", + ) + encs = _tree_panel_color_encodings(out) + legend = encs[0]["legend"] + assert legend["orient"] == "bottom" + assert 
"titleFontSize" not in legend + assert "labelFontSize" not in legend + # No smart-default columns when orient is bottom. + assert "columns" not in legend + + +def test_legend_format_font_sizes_propagate(): + out = tree_annotated_plot.plot( + _attr_auspice(), + _vertical_chart(["A", "B", "C", "D"]), + **_kw(), + color_tree_by="subclade", + tree_color_legend_format={"labelFontSize": 13, "titleFontSize": 13}, + ) + encs = _tree_panel_color_encodings(out) + for enc in encs: + legend = enc["legend"] + assert legend["labelFontSize"] == 13 + assert legend["titleFontSize"] == 13 + # Default orient stays bottom because we didn't override it. + assert legend["orient"] == "bottom" + + +def test_legend_format_orient_overrides_default(): + out = tree_annotated_plot.plot( + _attr_auspice(), + _vertical_chart(["A", "B", "C", "D"]), + **_kw(), + color_tree_by="subclade", + tree_color_legend_format={"orient": "right"}, + ) + encs = _tree_panel_color_encodings(out) + assert encs[0]["legend"]["orient"] == "right" + + +@pytest.mark.parametrize("orient", ["left", "right"]) +def test_legend_format_smart_default_columns_for_side_orients(orient): + out = tree_annotated_plot.plot( + _attr_auspice(), + _vertical_chart(["A", "B", "C", "D"]), + **_kw(), + color_tree_by="subclade", + tree_color_legend_format={"orient": orient}, + ) + legend = _tree_panel_color_encodings(out)[0]["legend"] + assert legend["orient"] == orient + assert legend["columns"] == 1 + + +def test_legend_format_user_columns_beats_smart_default(): + out = tree_annotated_plot.plot( + _attr_auspice(), + _vertical_chart(["A", "B", "C", "D"]), + **_kw(), + color_tree_by="subclade", + tree_color_legend_format={"orient": "left", "columns": 3}, + ) + legend = _tree_panel_color_encodings(out)[0]["legend"] + assert legend["columns"] == 3 + + +def test_legend_format_user_direction_disables_smart_default(): + out = tree_annotated_plot.plot( + _attr_auspice(), + _vertical_chart(["A", "B", "C", "D"]), + **_kw(), + 
color_tree_by="subclade", + tree_color_legend_format={"orient": "left", "direction": "horizontal"}, + ) + legend = _tree_panel_color_encodings(out)[0]["legend"] + assert legend["direction"] == "horizontal" + assert "columns" not in legend + + +@pytest.mark.parametrize("orient", ["top", "bottom", "top-left", "bottom-right"]) +def test_legend_format_smart_default_skipped_for_top_bottom_orients(orient): + out = tree_annotated_plot.plot( + _attr_auspice(), + _vertical_chart(["A", "B", "C", "D"]), + **_kw(), + color_tree_by="subclade", + tree_color_legend_format={"orient": orient}, + ) + legend = _tree_panel_color_encodings(out)[0]["legend"] + assert legend["orient"] == orient + # Top/bottom anchors should not get the vertical-stack smart default. + assert "columns" not in legend + + +# ----------------------------------------------------------------------------- +# tree_color_legend_show +# ----------------------------------------------------------------------------- + + +def test_legend_show_default_true(): + out = tree_annotated_plot.plot( + _attr_auspice(), + _vertical_chart(["A", "B", "C", "D"]), + **_kw(), + color_tree_by="subclade", + ) + encs = _tree_panel_color_encodings(out) + # Legend object is present. + assert encs[0].get("legend") + + +def test_legend_show_false_hides_legend(): + out = tree_annotated_plot.plot( + _attr_auspice(), + _vertical_chart(["A", "B", "C", "D"]), + **_kw(), + color_tree_by="subclade", + tree_color_legend_show=False, + ) + encs = _tree_panel_color_encodings(out) + # Color encodings are still present (tree is colored), but `legend` on + # each is null/missing — Altair drops the property when legend=None. 
+ assert encs + for enc in encs: + assert enc.get("legend") in (None, {}) or enc.get("legend") is None + + +# ----------------------------------------------------------------------------- +# scale_bar_font_size +# ----------------------------------------------------------------------------- + + +def _scale_bar_text_marks(out) -> list[dict]: + """Return every `mark_text` block on the tree panel — the scale-bar text + is the only text mark when connect_leader_to_label is off.""" + spec = out.to_dict() + panels = spec.get("hconcat") or spec.get("vconcat") or [] + tree_panel = panels[0] + hits: list[dict] = [] + + def walk(o: Any) -> None: + if isinstance(o, dict): + mark = o.get("mark") + if isinstance(mark, dict) and mark.get("type") == "text": + hits.append(mark) + elif mark == "text": + hits.append({"type": "text"}) + for v in o.values(): + walk(v) + elif isinstance(o, list): + for v in o: + walk(v) + + walk(tree_panel) + return hits + + +def test_scale_bar_font_size_default_is_10(): + out = tree_annotated_plot.plot( + _attr_auspice(), + _vertical_chart(["A", "B", "C", "D"]), + **_kw(), + scale_bar=True, + ) + text_marks = _scale_bar_text_marks(out) + # At least one text mark with fontSize=10 (the scale-bar label). 
+ sizes = [m.get("fontSize") for m in text_marks if "fontSize" in m] + assert 10 in sizes or 10.0 in sizes + + +def test_scale_bar_font_size_propagates(): + out = tree_annotated_plot.plot( + _attr_auspice(), + _vertical_chart(["A", "B", "C", "D"]), + **_kw(), + scale_bar=True, + scale_bar_font_size=18.0, + ) + text_marks = _scale_bar_text_marks(out) + sizes = [m.get("fontSize") for m in text_marks if "fontSize" in m] + assert 18.0 in sizes or 18 in sizes + + +# ----------------------------------------------------------------------------- +# CLI +# ----------------------------------------------------------------------------- + + +def test_cli_tree_color_scale(tmp_path): + tree_path, chart_path, out_path = _cli_setup( + tmp_path, _attr_auspice(), _vertical_chart(["A", "B", "C", "D"]) + ) + _run_cli( + [ + "--tree", + str(tree_path), + "--chart", + str(chart_path), + "--output", + str(out_path), + "--chart-strain-field", + "strain", + "--tree-strain-field", + "name", + "--branch-length", + "div", + "--color-tree-by", + "subclade", + "--tree-color-scale", + "Z=#111111,Y=#222222,X=#333333", + ] + ) + assert out_path.exists() + spec = json.loads(out_path.read_text()) + encs = [ + enc + for _, enc in _find_color_encodings(spec.get("hconcat", [{}])[0]) + if isinstance(enc, dict) and enc.get("field") == "color_value" + ] + assert encs + scale = encs[0]["scale"] + assert scale["domain"] == ["Z", "Y", "X"] + assert scale["range"] == ["#111111", "#222222", "#333333"] + + +def test_cli_tree_color_scale_invalid_format(tmp_path): + tree_path, chart_path, out_path = _cli_setup( + tmp_path, _attr_auspice(), _vertical_chart(["A", "B", "C", "D"]) + ) + from click.testing import CliRunner + + from tree_annotated_plot import cli as cli_module + + runner = CliRunner() + result = runner.invoke( + cli_module.main, + [ + "--tree", + str(tree_path), + "--chart", + str(chart_path), + "--output", + str(out_path), + "--chart-strain-field", + "strain", + "--tree-strain-field", + "name", + 
"--branch-length", + "div", + "--color-tree-by", + "subclade", + "--tree-color-scale", + "Z#111111", # missing '=' + ], + ) + assert result.exit_code != 0 + assert "key=color" in (result.output or "") or "key=color" in ( + str(result.exception) if result.exception else "" + ) + + +def test_cli_legend_format_json(tmp_path): + tree_path, chart_path, out_path = _cli_setup( + tmp_path, _attr_auspice(), _vertical_chart(["A", "B", "C", "D"]) + ) + _run_cli( + [ + "--tree", + str(tree_path), + "--chart", + str(chart_path), + "--output", + str(out_path), + "--chart-strain-field", + "strain", + "--tree-strain-field", + "name", + "--branch-length", + "div", + "--color-tree-by", + "subclade", + "--tree-color-legend-format", + '{"orient": "right", "labelFontSize": 14, "titleFontSize": 14}', + ] + ) + spec = json.loads(out_path.read_text()) + encs = [ + enc + for _, enc in _find_color_encodings(spec.get("hconcat", [{}])[0]) + if isinstance(enc, dict) and enc.get("field") == "color_value" + ] + assert encs + legend = encs[0]["legend"] + assert legend["orient"] == "right" + assert legend["titleFontSize"] == 14 + assert legend["labelFontSize"] == 14 + # Smart default fired: side-anchored orient + no user columns/direction. 
+ assert legend["columns"] == 1 + + +def test_cli_legend_format_invalid_json_errors(tmp_path): + tree_path, chart_path, out_path = _cli_setup( + tmp_path, _attr_auspice(), _vertical_chart(["A", "B", "C", "D"]) + ) + from click.testing import CliRunner + + from tree_annotated_plot import cli as cli_module + + runner = CliRunner() + result = runner.invoke( + cli_module.main, + [ + "--tree", + str(tree_path), + "--chart", + str(chart_path), + "--output", + str(out_path), + "--chart-strain-field", + "strain", + "--tree-strain-field", + "name", + "--branch-length", + "div", + "--color-tree-by", + "subclade", + "--tree-color-legend-format", + "{not valid json", + ], + ) + assert result.exit_code != 0 + + +def test_cli_legend_format_non_object_errors(tmp_path): + tree_path, chart_path, out_path = _cli_setup( + tmp_path, _attr_auspice(), _vertical_chart(["A", "B", "C", "D"]) + ) + from click.testing import CliRunner + + from tree_annotated_plot import cli as cli_module + + runner = CliRunner() + result = runner.invoke( + cli_module.main, + [ + "--tree", + str(tree_path), + "--chart", + str(chart_path), + "--output", + str(out_path), + "--chart-strain-field", + "strain", + "--tree-strain-field", + "name", + "--branch-length", + "div", + "--color-tree-by", + "subclade", + "--tree-color-legend-format", + '["not", "an", "object"]', + ], + ) + assert result.exit_code != 0 + + +def test_cli_no_tree_color_legend_show_hides_legend(tmp_path): + tree_path, chart_path, out_path = _cli_setup( + tmp_path, _attr_auspice(), _vertical_chart(["A", "B", "C", "D"]) + ) + _run_cli( + [ + "--tree", + str(tree_path), + "--chart", + str(chart_path), + "--output", + str(out_path), + "--chart-strain-field", + "strain", + "--tree-strain-field", + "name", + "--branch-length", + "div", + "--color-tree-by", + "subclade", + "--no-tree-color-legend-show", + ] + ) + spec = json.loads(out_path.read_text()) + encs = [ + enc + for _, enc in _find_color_encodings(spec.get("hconcat", [{}])[0]) + if 
isinstance(enc, dict) and enc.get("field") == "color_value" + ] + assert encs + for enc in encs: + assert enc.get("legend") in (None, {}) or enc.get("legend") is None + + +def test_cli_scale_bar_font_size(tmp_path): + tree_path, chart_path, out_path = _cli_setup( + tmp_path, _attr_auspice(), _vertical_chart(["A", "B", "C", "D"]) + ) + _run_cli( + [ + "--tree", + str(tree_path), + "--chart", + str(chart_path), + "--output", + str(out_path), + "--chart-strain-field", + "strain", + "--tree-strain-field", + "name", + "--branch-length", + "div", + "--scale-bar", + "--scale-bar-font-size", + "16", + ] + ) + spec = json.loads(out_path.read_text()) + panels = spec.get("hconcat", []) + tree_panel = panels[0] if panels else {} + + def collect_text_marks(o: Any, hits: list[dict]) -> None: + if isinstance(o, dict): + mark = o.get("mark") + if isinstance(mark, dict) and mark.get("type") == "text": + hits.append(mark) + for v in o.values(): + collect_text_marks(v, hits) + elif isinstance(o, list): + for v in o: + collect_text_marks(v, hits) + + text_marks: list[dict] = [] + collect_text_marks(tree_panel, text_marks) + sizes = [m.get("fontSize") for m in text_marks if "fontSize" in m] + assert 16.0 in sizes or 16 in sizes