# CosmoTree (Gtol)

A single-notebook pipeline to render large weighted phylogenetic tree with Cosmograph. The notebook:
- Parses Newick into a lightweight typed tree
- Computes cumulative branch-length X and equal-spaced leaf Y
- Lays out an orthogonal tree with elbows and non-overlapping stems
- Renders via `cosmograph_widget.Cosmograph` using scalable per-kind sizes

Performance rationale: Cosmograph handles tens of thousands of nodes interactively; matplotlib/similar packages are too slow for this scale.


In [None]:
# Environment setup
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Iterable
import re
import sys
import os
import pandas as pd


def _log(*args: object) -> None:
    print(*args, file=sys.stderr)

In [None]:
# Configuration constants
# Geometry & appearance
X_SCALE_PX: float = 140.0  # px per branch-length unit
MIN_STEM_GAP_PX: float = 56.0  # min horizontal gap between adjacent vertical stems
PARENT_STUB_PX: float = 20.0  # elbow stub length before vertical
WEIGHTED_STUB_PX: float = 40.0  # minimal horizontal stub to weighted segment (set to 0 to preserve the ratio)
LEAF_Y_STEP_PX: float = 400.0  # vertical spacing between consecutive leaves
TIP_PAD_PX: float = 40.0  # extra space right of farthest leaf for markers

# Per-kind point sizes (in pixels for direct strategy)
SIZE_LEAF_MARKER: float = 20.0
SIZE_INTERNAL: float = 6.0
SIZE_BEND: float = 3.0
SIZE_LEAF_REAL: float = 8.0

# Global size scaling factor (applied before Cosmograph)
NODE_SIZE_SCALE: float = 2.0

# Colors
COLOR_LEAF: str = "#f5d76e"  # yellow
COLOR_INTERNAL: str = "#8ab4f8"  # light blue
COLOR_BEND: str = "#9aa0a6"  # gray
COLOR_LINK: str = "#97A1A9"  # gray
LINK_WIDTH_PX: float = 0.7

## Newick parsing

We tokenize Newick with a small (not really) regex, then build a typed node map using a stack:
- On `(` create an internal node and descend; on `)` assign pending attributes and ascend
- Leaf tokens become nodes attached to current parent
- `:length` applies to the last emitted node or is stored pending for the just-closed group
- If multiple roots appear, we synthesize a single `root0`

Complexity is linear in input length. Names and lengths are preserved; branch lengths < 0 are clamped later during layout.


In [None]:
# Newick parsing — code

@dataclass
class TNode:
    """A tree node parsed from Newick.

    Attributes:
        id: Stable synthetic identifier.
        name: Label (for leaves typically), may be empty.
        parent: Parent node id or None for root.
        blen: Branch length from parent to this node (non-negative real expected).
        children: Child node ids in insertion order.
    """
    id: str
    name: str = ""
    parent: Optional[str] = None
    blen: float = 0.0
    children: List[str] = field(default_factory=list)

_TOKEN_RE = re.compile(r"\s*([(),;])\s*|\s*([^(),:;]+)\s*|(\s*:\s*[-+]?\d*\.?\d+(?:[eE][-+]?\d+)?)")  # help me


def _tokenize(newick: str) -> Iterable[str]:
    """Yield tokens from Newick: parens, commas, semicolon, names, and ':len'."""
    i = 0
    s = newick.strip()
    while i < len(s):
        m = _TOKEN_RE.match(s, i)
        if not m:
            if s[i].isspace():
                i += 1
                continue
            raise ValueError(f"[Newick] Unexpected token near: {s[i:i+20]!r}")
        tok = m.group(1) or m.group(2) or m.group(3)
        if tok is not None:
            yield tok.strip()
        i = m.end()


def parse_newick(newick: str) -> Dict[str, TNode]:
    """Parse Newick text into a node-id -> TNode map.

    Rules:
    - Internal groups create anonymous nodes; leaves are named tokens.
    - ':len' attaches to the last emitted node or the just-closed group.
    - Multiple top-level groups are unified under synthetic 'root0'.
    """
    nodes: Dict[str, TNode] = {}
    stack: List[Optional[str]] = []
    last: Optional[str] = None
    nid = 0

    def new_id() -> str:
        nonlocal nid
        nid += 1
        return f"n{nid}"

    current_parent: Optional[str] = None
    pending_name: Optional[str] = None
    pending_len: Optional[float] = None

    for tok in _tokenize(newick):
        if tok == "(":
            u = new_id()
            nodes[u] = TNode(id=u)
            if current_parent is not None:
                nodes[current_parent].children.append(u)
                nodes[u].parent = current_parent
            stack.append(current_parent)
            current_parent = u
            last = None
        elif tok == ",":
            last = None
            pending_name = None
            pending_len = None
        elif tok == ")":
            if pending_name is not None:
                nodes[current_parent].name = pending_name
                pending_name = None
            if pending_len is not None:
                nodes[current_parent].blen = float(pending_len)
                pending_len = None
            current_parent = stack.pop()
            last = None
        elif tok == ";":
            break
        elif tok.startswith(":"):
            L = float(tok[1:].strip())
            if last is None:
                pending_len = L
            else:
                nodes[last].blen = L
        else:
            u = new_id()
            nodes[u] = TNode(id=u, name=tok, parent=current_parent)
            if current_parent is not None:
                nodes[current_parent].children.append(u)
            last = u
            pending_name = None
            pending_len = None

    roots = [k for k, v in nodes.items() if v.parent is None]
    if not roots:
        raise ValueError("[Parse] No root detected.")
    if len(roots) == 1:
        root_id = roots[0]
    else:
        root_id = "root0"
        nodes[root_id] = TNode(id=root_id, name="root", parent=None, blen=0.0, children=roots)
        for r in roots:
            nodes[r].parent = root_id

    _log(f"[Parse] nodes={len(nodes):,} leaves={sum(1 for v in nodes.values() if not v.children):,} root={root_id}")
    return nodes


## Tree utilities

We derive:
- Root detection by `parent is None`
- Leaf collection via DFS
- Child ordering to avoid crossings by sorting children by their minimal leaf name
- Cumulative branch length distances (X, later scaled)
- Equal-spaced leaf Y, then postorder parent Y as mean of children

All operations are linear in node count.


In [12]:
# Tree utilities — code
from typing import Set


def _collect_leaves(nodes: Dict[str, TNode], u: str) -> List[str]:
    """Return a list of leaf node ids under `u` (inclusive if `u` is a leaf)."""
    if not nodes[u].children:
        return [u]
    acc: List[str] = []
    for c in nodes[u].children:
        acc.extend(_collect_leaves(nodes, c))
    return acc


def _find_root(nodes: Dict[str, TNode]) -> str:
    """Return the single root id (node with `parent is None`)."""
    for k, v in nodes.items():
        if v.parent is None:
            return k
    raise ValueError("No root")


def _sort_children_for_no_crossing(nodes: Dict[str, TNode]) -> None:
    """Sort children of each internal node by the minimal leaf name to reduce crossings."""
    def min_leaf_name(u: str) -> str:
        names = [nodes[x].name or x for x in _collect_leaves(nodes, u)]
        return min(names)
    for u, v in nodes.items():
        if v.children:
            v.children.sort(key=min_leaf_name)


def compute_cumdist(nodes: Dict[str, TNode], root: Optional[str] = None) -> Dict[str, float]:
    """Compute cumulative branch length distance from `root` to each node."""
    if root is None:
        root = _find_root(nodes)
    dist: Dict[str, float] = {root: 0.0}
    stack: List[str] = [root]
    while stack:
        u = stack.pop()
        for c in nodes[u].children:
            dist[c] = dist[u] + max(0.0, float(nodes[c].blen))
            stack.append(c)
    return dist


def assign_y_equal_leaf_spacing(nodes: Dict[str, TNode], leaf_step: float) -> Dict[str, float]:
    """Assign Y such that leaves are equally spaced by `leaf_step`, parents at child mean."""
    root = _find_root(nodes)
    _sort_children_for_no_crossing(nodes)
    leaves = _collect_leaves(nodes, root)
    y: Dict[str, float] = {lf: i * leaf_step for i, lf in enumerate(leaves)}

    # Postorder to compute parent means
    order: List[str] = []
    stack: List[str] = [root]
    visited: Set[str] = set()
    while stack:
        u = stack.pop()
        order.append(u)
        for c in nodes[u].children:
            if c not in visited:
                stack.append(c)
        visited.add(u)
    order.reverse()

    for u in order:
        if nodes[u].children:
            y[u] = sum(y[c] for c in nodes[u].children) / len(nodes[u].children)
        elif u not in y:
            y[u] = 0.0
    return y


## Layout and display graph build

We map tree metrics to a display graph:
- Scale cumulative distances to pixels for X
- Create a non-overlapping set of vertical stems by spreading (quantized keys)
- Build orthogonal edges: parent→elbow→vertical→child
- Cache bend nodes and links to avoid duplicates
- Add aligned right-tip markers at a shared X

Outputs are two DataFrames: `nodes_df` (points) and `links_df` (edges).


In [13]:
# Layout — code
from typing import Any, Set


def build_display_graph(
    nodes: Dict[str, TNode],
    *,
    leaf_step: float = LEAF_Y_STEP_PX,
    parent_stub: float = PARENT_STUB_PX,
    tip_pad: float = TIP_PAD_PX,
    x_scale: float = X_SCALE_PX,
    min_level_gap: float = MIN_STEM_GAP_PX,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Create node/link DataFrames for Cosmograph.

    Stages: compute cumulative X (scaled), spread stems to avoid overlap, build orthogonal
    edges with bend caching, and add right-aligned leaf markers.
    """
    def logv(*a: object) -> None:
        _log("[Layout]", *a)

    root = _find_root(nodes)
    dist = compute_cumdist(nodes, root)
    y = assign_y_equal_leaf_spacing(nodes, leaf_step)

    # 1) scale X by branch length
    dist_px: Dict[str, float] = {u: float(dist[u]) * float(x_scale) for u in nodes}

    # 2) compute stems and spread horizontally so verticals don't overlap
    raw_stems = sorted({dist_px[u] + parent_stub for u in nodes})
    spread_stems: List[float] = []
    last: Optional[float] = None
    for sx in raw_stems:
        spread = sx if last is None else max(sx, last + float(min_level_gap))
        spread_stems.append(spread)
        last = spread

    def q(v: float) -> float:
        return float(f"{v:.6f}")

    stem_map: Dict[float, float] = {q(o): s for o, s in zip(raw_stems, spread_stems)}

    def stem_x(u: str) -> float:
        return stem_map[q(dist_px[u] + parent_stub)]

    leaves = [k for k, v in nodes.items() if not v.children]

    pts: List[Dict[str, Any]] = []
    links: List[Dict[str, Any]] = []
    node_id_actual: Dict[str, int] = {}
    next_id = 0
    point_cache: Dict[Tuple[str, float, float], int] = {}
    link_cache: Set[Tuple[int, int]] = set()

    def add_point(kind: str, x: float, yv: float, *, label: str = "", color: Optional[str] = None,
                  size: Optional[float] = None, cache_bend: bool = True) -> int:
        nonlocal next_id
        xq, yq = q(x), q(yv)
        if cache_bend and kind == "bend":
            key = (kind, xq, yq)
            if key in point_cache:
                return point_cache[key]
        pid = next_id
        next_id += 1

        if size is None:
            if kind == "leaf":
                size = SIZE_LEAF_REAL
            elif kind == "leaf_marker":
                size = SIZE_LEAF_MARKER
            elif kind == "internal":
                size = SIZE_INTERNAL
            elif kind == "bend":
                size = SIZE_BEND
            else:
                size = 4.0

        if color is None:
            if kind == "leaf" or kind == "leaf_marker":
                color = COLOR_LEAF
            elif kind == "internal":
                color = COLOR_INTERNAL
            elif kind == "bend":
                color = COLOR_BEND
            else:
                color = COLOR_BEND

        pts.append({
            "id": int(pid),
            "x": xq,
            "y": yq,
            "size": size,
            "color": color,
            "label": label,
            "kind": kind,
        })
        if cache_bend and kind == "bend":
            point_cache[(kind, xq, yq)] = pid
        return pid

    def add_link(s: int, t: int, *, color: Optional[str] = None) -> None:
        key = (int(s), int(t))
        if key in link_cache:
            return
        link_cache.add(key)
        links.append({"source": int(s), "target": int(t), "color": color or COLOR_LINK})

    # 3) create real tree nodes at initial X; move children afterward
    for u, v in nodes.items():
        kind = "leaf" if not v.children else "internal"
        label = v.name if v.name else (u if kind == "leaf" else "")
        ex = q(stem_x(u))  # elbow x
        node_x_display = q(ex - float(PARENT_STUB_PX))
        pid = add_point(kind, node_x_display, y[u], label=label, cache_bend=False)
        node_id_actual[u] = pid
    id_to_idx = {row["id"]: i for i, row in enumerate(pts)}

    # 4) orthogonal edges; move children to final weighted X
    EPS = 1e-6
    for u, v in nodes.items():
        ex = q(stem_x(u))
        y_parent = y[u]
        for c in v.children:
            y_child = y[c]
            true_len_px = max(0.0, float(nodes[c].blen)) * float(x_scale)
            child_pid = node_id_actual[c]
            pts[id_to_idx[child_pid]]["x"] = q(ex + float(WEIGHTED_STUB_PX) + true_len_px)

            elbow_top = add_point("bend", ex, y_parent, cache_bend=True)
            add_link(node_id_actual[u], elbow_top)
            if abs(y_parent - y_child) > EPS:
                elbow_bot = add_point("bend", ex, y_child, cache_bend=True)
                add_link(elbow_top, elbow_bot)
                add_link(elbow_bot, child_pid)
            else:
                add_link(elbow_top, child_pid)

    # 5) aligned right guideline X and leaf markers
    max_leaf_x = max(pts[id_to_idx[node_id_actual[lf]]]["x"] for lf in leaves) if leaves else 0.0
    x_tipline = max_leaf_x + tip_pad
    for lf in leaves:
        leaf_row = pts[id_to_idx[node_id_actual[lf]]]
        pid = add_point(
            "leaf_marker", x_tipline, leaf_row["y"], label=leaf_row["label"], color=COLOR_LEAF,
            size=SIZE_LEAF_MARKER, cache_bend=False,
        )
        add_link(pid, leaf_row["id"])

    nodes_df = pd.DataFrame(pts)
    links_df = pd.DataFrame(links)
    logv(f"nodes_df={nodes_df.shape} links_df={links_df.shape} (markers @ x={x_tipline:.1f})")
    return nodes_df, links_df


## Rendering with Cosmograph

We map DataFrame columns to Cosmograph fields and disable the physics simulation:
- Coordinates are scaled (compact view), sizes are per-kind to emphasize leaves and markers
- Labels are shown only when hovered or dynamically when space allows
- Links are straight segments defined by source/target ids

This preserves interactivity for very large trees.


In [14]:
# Rendering — code

def render_cosmograph(nodes_df: pd.DataFrame, links_df: pd.DataFrame, *, link_px: float = LINK_WIDTH_PX):
    """Create a Cosmograph widget with scaled coordinates and per-kind sizes."""
    from cosmograph_widget import Cosmograph

    for c in ["id", "x", "y", "size", "color", "label"]:
        assert c in nodes_df.columns, f"nodes missing {c}"
    for c in ["source", "target", "color"]:
        assert c in links_df.columns, f"links missing {c}"

    nodes_df = nodes_df.copy()
    links_df = links_df.copy()
    nodes_df["id"] = nodes_df["id"].astype(int)
    nodes_df["x"] = nodes_df["x"].astype(float)
    nodes_df["y"] = nodes_df["y"].astype(float)
    nodes_df["size"] = (nodes_df["size"].astype(float) * NODE_SIZE_SCALE)
    links_df["source"] = links_df["source"].astype(int)
    links_df["target"] = links_df["target"].astype(int)

    # Compact view scaling
    scale_factor: float = 0.3
    nodes_df_scaled = nodes_df.copy()
    nodes_df_scaled["x"] = nodes_df_scaled["x"] * scale_factor
    nodes_df_scaled["y"] = nodes_df_scaled["y"] * scale_factor

    def size_by_kind(kind: str) -> float:
        if kind == "leaf_marker":
            return 24.0
        if kind == "leaf":
            return 18.0
        if kind == "internal":
            return 10.0
        if kind == "bend":
            return 3.0
        return 8.0

    nodes_df_scaled = nodes_df_scaled.copy()
    nodes_df_scaled["pixel_size"] = nodes_df_scaled["kind"].apply(size_by_kind).astype(float)

    w = Cosmograph(
        points=nodes_df_scaled,
        links=links_df,
        point_id_by="id",
        point_x_by="x",
        point_y_by="y",
        point_color_by="color",
        point_size_by="pixel_size",
        point_label_by="label",
        link_source_by="source",
        link_target_by="target",
        link_color_by="color",
        link_width_by=None,
        link_width=float(link_px) * 2,
        disable_simulation=True,
        fit_view_on_init=True,
        fit_view_padding=0.06,
        enable_drag=False,
        show_hovered_point_label=True,
        show_dynamic_labels=True,
        show_legends=False,
        scale_points_on_zoom=True,
    )
    return w


## I/O and run

We read a local Newick file path, take the first tree terminated by `;`, build the display graph, and render the widget. Adjust parameters if lines are cramped or too sparse.


In [15]:
# I/O and execution — code

def read_newick_input(path_or_text: str) -> str:
    """Read from local file if path exists; otherwise treat input as literal Newick."""
    if os.path.exists(path_or_text):
        with open(path_or_text, "r", encoding="utf-8", errors="ignore") as f:
            s = f.read()
        _log(f"[Input] Read file '{path_or_text}' len={len(s):,}")
        return s.strip()
    _log(f"[Input] Treating argument as literal Newick (len={len(path_or_text):,})")
    return path_or_text.strip()


def run_all(
    path_or_text: str,
    *,
    leaf_step: float = LEAF_Y_STEP_PX,
    parent_stub: float = PARENT_STUB_PX,
    tip_pad: float = TIP_PAD_PX,
    x_scale: float = X_SCALE_PX,
    min_level_gap: float = MIN_STEM_GAP_PX,
):
    """End-to-end: parse Newick, build display graph, render cosmograph.

    Returns widget and the two DataFrames for downstream inspection.
    """
    newick = read_newick_input(path_or_text)
    parts = [p.strip() for p in newick.split(";") if p.strip()]
    if not parts:
        raise ValueError("No Newick tree found.")
    tree_s = parts[0] + ";"
    _log(f"[Run] Using first tree substring len={len(tree_s)}")

    nodes = parse_newick(tree_s)
    nodes_df, links_df = build_display_graph(
        nodes,
        leaf_step=leaf_step,
        parent_stub=parent_stub,
        tip_pad=tip_pad,
        x_scale=x_scale,
        min_level_gap=min_level_gap,
    )
    _log("[Run] Rendering widget...")
    w = render_cosmograph(nodes_df, links_df)
    return w, nodes_df, links_df


# Local path and parameters (edit here)
NEWICK: str = "/Users/gushchin_a/Downloads/Chond 10Cal 10k TreeSet.tre"
params = dict(
    x_scale=140.0,
    min_level_gap=56.0,
    leaf_step=400.0,
    parent_stub=20.0,
    tip_pad=40.0,
)

w, nodes_df, links_df = run_all(NEWICK, **params)

try:
    from IPython.display import display
    display(w)
except Exception:
    print("Cosmograph widget created:", w)


[Input] Read file '/Users/gushchin_a/Downloads/Chond 10Cal 10k TreeSet.tre' len=549,947,364
[Run] Using first tree substring len=54968
[Parse] nodes=2,383 leaves=1,192 root=n1
[Layout] nodes_df=(6580, 7) links_df=(7147, 3) (markers @ x=98540.9)
[Run] Rendering widget...


Cosmograph(background_color=None, disable_simulation=True, enable_drag=False, fit_view_on_init=True, fit_view_…