# CosmoTree (Gtol)

A single-notebook pipeline to render large weighted phylogenetic trees with Cosmograph. The notebook:
- Parses Newick into a lightweight typed tree
- Computes cumulative branch-length X and equal-spaced leaf Y
- Lays out an orthogonal tree with elbows and non-overlapping stems
- Renders via `cosmograph_widget.Cosmograph` using scalable per-kind sizes

Performance rationale: Cosmograph handles tens of thousands of nodes interactively; matplotlib and similar plotting libraries are too slow at this scale.


In [None]:
# Environment setup
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Iterable
import re
import sys
import os
import pandas as pd
import numpy as np
import time


def _log(*args: object) -> None:
    print(*args, file=sys.stderr)

: 

In [None]:
# Configuration constants
# Geometry & appearance
X_SCALE_PX: float = 140.0  # px per unit of cumulative branch length (X axis)
MIN_STEM_GAP_PX: float = 56.0  # min horizontal gap between adjacent vertical stems
PARENT_STUB_PX: float = 20.0  # horizontal stub from a node to its elbow (vertical stem)
# Fixed offset from a parent's stem to each child, added before the child's scaled
# branch length; set to 0 so each child sits exactly blen * X_SCALE_PX right of the stem.
WEIGHTED_STUB_PX: float = 40.0
LEAF_Y_STEP_PX: float = 400.0  # vertical spacing between consecutive leaves
TIP_PAD_PX: float = 40.0  # extra space right of the farthest leaf for the tip markers

# Per-kind point sizes written to nodes_df["size"] by build_display_graph.
# NOTE: render_cosmograph sizes points from its own kind -> size table by default
# and does not read the "size" column.
SIZE_LEAF_MARKER: float = 20.0
SIZE_INTERNAL: float = 6.0
SIZE_BEND: float = 3.0
SIZE_LEAF_REAL: float = 8.0

# NOTE(review): not referenced by any code in this notebook, so changing it
# currently has no visible effect.
NODE_SIZE_SCALE: float = 2.0

# Colors
COLOR_LEAF: str = "#f5d76e"  # yellow (leaves and tip markers)
COLOR_INTERNAL: str = "#8ab4f8"  # light blue
COLOR_BEND: str = "#9aa0a6"  # gray (elbow points)
COLOR_LINK: str = "#97A1A9"  # gray
LINK_WIDTH_PX: float = 0.7  # default link_px of render_cosmograph (the widget receives 2x this)

## Newick parsing

We tokenize Newick with a single regular expression, then build a typed node map using a stack:
- On `(` create an internal node and descend; on `)` close the current group and ascend
- Leaf tokens become nodes attached to the current parent
- A label or `:length` written right after `)` belongs to the just-closed group; otherwise `:length` applies to the last emitted leaf
- If multiple roots appear, we synthesize a single `root0`

Complexity is linear in input length. Names and lengths are preserved; branch lengths < 0 are clamped later during layout.


In [None]:
# Newick parsing — code

@dataclass
class TNode:
    """One node of a parsed Newick tree.

    Nodes reference each other by id rather than by object, so a whole tree is a
    flat ``Dict[str, TNode]`` keyed by ``id``.
    """

    # Synthetic identifier assigned by the parser ("n1", "n2", ...; "root0" when a root is synthesized).
    id: str
    # Newick label; may be empty (e.g. unnamed internal nodes).
    name: str = ""
    # Id of the parent node, or None for the root.
    parent: Optional[str] = None
    # Length of the branch from the parent to this node; expected non-negative
    # (compute_cumdist clamps negative values to 0).
    blen: float = 0.0
    # Child ids; initially in Newick text order, later reordered in place by
    # _sort_children_for_no_crossing. Each instance gets its own list.
    children: List[str] = field(default_factory=list)

_TOKEN_RE = re.compile(r"\s*([(),;])\s*|\s*([^(),:;]+)\s*|(\s*:\s*[-+]?\d*\.?\d+(?:[eE][-+]?\d+)?)")  # help me


def _tokenize(newick: str) -> Iterable[str]:
    """Yield tokens from Newick: parens, commas, semicolon, names, and ':len'."""
    i = 0
    s = newick.strip()
    while i < len(s):
        m = _TOKEN_RE.match(s, i)
        if not m:
            if s[i].isspace():
                i += 1
                continue
            raise ValueError(f"[Newick] Unexpected token near: {s[i:i+20]!r}")
        tok = m.group(1) or m.group(2) or m.group(3)
        if tok is not None:
            yield tok.strip()
        i = m.end()


def parse_newick(newick: str) -> Dict[str, TNode]:
    """Parse Newick text into a node-id -> TNode map.

    Rules:
    - '(' opens an anonymous internal node under the current group; ')' closes it.
    - A label and/or ':len' written directly after ')' belong to the group that was
      just closed: in '(A:1,B:2)C:3' the internal node is named 'C' with blen 3.
    - Any other label is a leaf of the current group; a following ':len' sets that
      leaf's branch length.
    - An empty slot inside a group ('(,B)', '(A,)', '(:1,:2)') becomes an unnamed
      leaf, as in the standard Newick grammar.
    - Parsing stops at the first ';'.
    - Multiple top-level subtrees are unified under a synthetic 'root0'.

    Returns:
        Map from synthetic id ('n1', 'n2', ..., plus 'root0' if synthesized) to node.

    Raises:
        ValueError: on unbalanced parentheses, a label in an unexpected position,
            an unrecognized token (raised by `_tokenize`), or input without nodes.
    """
    nodes: Dict[str, TNode] = {}
    stack: List[Optional[str]] = []  # enclosing group of every currently open '('
    current_parent: Optional[str] = None  # innermost open group (None = top level)
    # Node occupying the current slot: the last leaf, or the group just closed by ')'.
    # None right after '(' or ',' while the slot is still empty.
    last: Optional[str] = None
    # True while `last` is a just-closed group that may still receive a label.
    label_open = False
    nid = 0

    def new_node(name: str = "") -> str:
        """Create a node under the current group and return its id."""
        nonlocal nid
        nid += 1
        u = f"n{nid}"
        nodes[u] = TNode(id=u, name=name, parent=current_parent)
        if current_parent is not None:
            nodes[current_parent].children.append(u)
        return u

    for tok in _tokenize(newick):
        if tok == "(":
            u = new_node()
            stack.append(current_parent)
            current_parent = u
            last = None
            label_open = False
        elif tok == ",":
            if last is None and current_parent is not None:
                new_node()  # empty slot such as '(,B)': unnamed leaf
            last = None
            label_open = False
        elif tok == ")":
            if current_parent is None:
                raise ValueError("[Newick] Unmatched ')'.")
            if last is None:
                new_node()  # empty slot such as '(A,)' or '()': unnamed leaf
            # The closed group is now the current slot: a following label or
            # ':len' describes it (not the enclosing group).
            last = current_parent
            label_open = True
            current_parent = stack.pop()
        elif tok == ";":
            break
        elif tok.startswith(":"):
            if last is None:
                last = new_node()  # e.g. '(:1,:2)': unnamed leaf carrying a length
            nodes[last].blen = float(tok[1:].strip())
            label_open = False  # in Newick a label precedes the length
        elif last is None:
            last = new_node(tok)
        elif label_open:
            nodes[last].name = tok
            label_open = False
        else:
            raise ValueError(f"[Newick] Unexpected label {tok!r}.")

    if current_parent is not None:
        raise ValueError("[Newick] Unbalanced parentheses: missing ')'.")

    roots = [k for k, v in nodes.items() if v.parent is None]
    if not roots:
        raise ValueError("[Parse] No root detected.")
    if len(roots) == 1:
        root_id = roots[0]
    else:
        # Synthetic ids are 'n<k>', so 'root0' cannot collide with a parsed node.
        root_id = "root0"
        nodes[root_id] = TNode(id=root_id, name="root", parent=None, blen=0.0, children=roots)
        for r in roots:
            nodes[r].parent = root_id

    _log(f"[Parse] nodes={len(nodes):,} leaves={sum(1 for v in nodes.values() if not v.children):,} root={root_id}")
    return nodes


## Tree utilities

We derive:
- Root detection by `parent is None`
- Leaf collection via DFS
- Deterministic child ordering: children are sorted by the smallest leaf name in their subtree
- Cumulative branch length distances (X, later scaled)
- Equal-spaced leaf Y, then postorder parent Y as mean of children

All steps are linear in node count, apart from sorting each node's children.


In [None]:
# Tree utilities — code
from typing import Set


def _collect_leaves(nodes: Dict[str, TNode], u: str) -> List[str]:
    """Return a list of leaf node ids under `u` (inclusive if `u` is a leaf)."""
    if not nodes[u].children:
        return [u]
    acc: List[str] = []
    for c in nodes[u].children:
        acc.extend(_collect_leaves(nodes, c))
    return acc


def _find_root(nodes: Dict[str, TNode]) -> str:
    """Return the single root id (node with `parent is None`)."""
    for k, v in nodes.items():
        if v.parent is None:
            return k
    raise ValueError("No root")


def _sort_children_for_no_crossing(nodes: Dict[str, TNode]) -> None:
    """Reorder, in place, every node's children by the smallest leaf label in each child's subtree.

    A leaf's label is its `name`, or its node id when the name is empty. Only nodes
    reachable from the root returned by `_find_root` are touched. The sort is stable,
    so children with equal keys keep their relative order. Runs without recursion.
    """
    root = _find_root(nodes)

    # Preorder via an explicit stack; walking it backwards visits every child
    # before its parent (a valid postorder).
    order: List[str] = []
    pending: List[str] = [root]
    while pending:
        node_id = pending.pop()
        order.append(node_id)
        pending.extend(nodes[node_id].children)

    smallest: Dict[str, str] = {}  # node id -> smallest leaf label in its subtree
    for node_id in reversed(order):
        kids = nodes[node_id].children
        if not kids:
            smallest[node_id] = nodes[node_id].name or node_id
            continue
        kids.sort(key=smallest.__getitem__)
        # After an ascending sort the first child carries the subtree minimum.
        smallest[node_id] = smallest[kids[0]]


def compute_cumdist(nodes: Dict[str, "TNode"], root: Optional[str] = None) -> Dict[str, float]:
    """Return the summed branch length from `root` to every node beneath it (root -> 0.0).

    Negative branch lengths count as 0. When `root` is None, `_find_root` picks it.
    Only nodes reachable from `root` appear in the result.
    """
    start = _find_root(nodes) if root is None else root
    depth: Dict[str, float] = {start: 0.0}
    todo: List[str] = [start]
    while todo:
        parent_id = todo.pop()
        base = depth[parent_id]
        for child_id in nodes[parent_id].children:
            depth[child_id] = base + max(0.0, float(nodes[child_id].blen))
            todo.append(child_id)
    return depth


def assign_y_equal_leaf_spacing(nodes: Dict[str, TNode], leaf_step: float) -> Dict[str, float]:
    """Assign Y so leaves are `leaf_step` apart and every parent sits at its children's mean.

    Side effect: first reorders every node's children in place via
    `_sort_children_for_no_crossing`. Leaves then receive 0, leaf_step,
    2 * leaf_step, ... in left-to-right order of those sorted children lists.

    Uses explicit stacks instead of recursion, so trees deeper than Python's
    recursion limit (common for large phylogenies) do not raise RecursionError.
    """
    root = _find_root(nodes)
    _sort_children_for_no_crossing(nodes)

    # Left-to-right preorder: children are pushed reversed so the first one pops first.
    preorder: List[str] = []
    stack: List[str] = [root]
    while stack:
        u = stack.pop()
        preorder.append(u)
        stack.extend(reversed(nodes[u].children))

    y: Dict[str, float] = {}

    # Leaves appear in preorder in left-to-right order; number them consecutively.
    leaf_counter = 0
    for u in preorder:
        if not nodes[u].children:
            y[u] = leaf_counter * leaf_step
            leaf_counter += 1

    # Reversed preorder visits children before parents, so every child's Y is ready.
    for u in reversed(preorder):
        kids = nodes[u].children
        if kids:
            y[u] = sum(y[c] for c in kids) / len(kids)
    return y


## Layout and display graph build

We map tree metrics to a display graph:
- Scale cumulative distances to pixels for X
- Create a non-overlapping set of vertical stems by spreading (quantized keys)
- Build orthogonal edges: parent→elbow→vertical→child
- Cache bend nodes and links to avoid duplicates
- Add aligned right-tip markers at a shared X

Outputs are two DataFrames: `nodes_df` (points) and `links_df` (edges).


In [None]:
# Layout — code
from typing import Any, Set


def build_display_graph(
    nodes: Dict[str, TNode],
    *,
    leaf_step: float = LEAF_Y_STEP_PX,
    parent_stub: float = PARENT_STUB_PX,
    tip_pad: float = TIP_PAD_PX,
    x_scale: float = X_SCALE_PX,
    min_level_gap: float = MIN_STEM_GAP_PX,
    weighted_stub: float = WEIGHTED_STUB_PX,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Create node/link DataFrames for Cosmograph.

    Stages:
    1. cumulative branch length per node, scaled by `x_scale`, and equal-spaced
       leaf Y (via assign_y_equal_leaf_spacing);
    2. one vertical stem X per distinct ``dist_px + parent_stub`` value, pushed right
       so consecutive stems are at least `min_level_gap` apart;
    3. one point per tree node, initially `parent_stub` left of its own stem;
    4. orthogonal links node -> elbow(s) -> child, moving each child to
       ``stem(parent) + weighted_stub + blen * x_scale``; elbow points at equal
       (x, y) and duplicate links are created only once;
    5. one right-aligned marker per leaf at ``max leaf X + tip_pad``.

    Args:
        nodes: Parsed tree (children lists are reordered in place during stage 1).
        leaf_step: Vertical distance between consecutive leaves.
        parent_stub: Horizontal stub between a node and its stem.
        tip_pad: Gap between the rightmost leaf and the marker column.
        x_scale: Pixels per unit of branch length.
        min_level_gap: Minimum horizontal distance between adjacent stems.
        weighted_stub: Fixed offset from a parent's stem to each child, added before
            the scaled branch length (new keyword; defaults to WEIGHTED_STUB_PX).

    Returns:
        (nodes_df, links_df). nodes_df columns: id, x, y, size, color, label, kind
        (kind is "leaf", "internal", "bend" or "leaf_marker"); point ids are
        0..len(nodes_df)-1. links_df columns: source, target, color.
    """
    def logv(*a: object) -> None:
        _log("[Layout]", *a)

    t0 = time.time()
    root = _find_root(nodes)
    dist = compute_cumdist(nodes, root)
    y = assign_y_equal_leaf_spacing(nodes, leaf_step)
    logv(f"dist & y assignment: {time.time()-t0:.2f}s")

    # 1) scale X by branch length
    dist_px: Dict[str, float] = {u: float(dist[u]) * float(x_scale) for u in nodes}

    # 2) compute stems and spread horizontally so verticals don't overlap:
    #    each stem is pushed right to at least `min_level_gap` past the previous one.
    raw_stems = sorted({dist_px[u] + parent_stub for u in nodes})
    spread_stems: List[float] = []
    last: Optional[float] = None
    for sx in raw_stems:
        spread = sx if last is None else max(sx, last + float(min_level_gap))
        spread_stems.append(spread)
        last = spread

    def q(v: float) -> float:
        """Round to 6 decimals so float noise cannot split equal coordinates into separate keys."""
        return float(f"{v:.6f}")

    stem_map: Dict[float, float] = {q(o): s for o, s in zip(raw_stems, spread_stems)}

    def stem_x(u: str) -> float:
        return stem_map[q(dist_px[u] + parent_stub)]

    leaves = [k for k, v in nodes.items() if not v.children]

    # Column lists for the points. Ids are handed out densely from 0, so a point's
    # id is also its row index: pts_*[pid] addresses that point directly.
    pts_id: List[int] = []
    pts_x: List[float] = []
    pts_y: List[float] = []
    pts_size: List[float] = []
    pts_color: List[str] = []
    pts_label: List[str] = []
    pts_kind: List[str] = []

    links_source: List[int] = []
    links_target: List[int] = []
    links_color: List[str] = []

    node_id_actual: Dict[str, int] = {}  # tree node id -> point id
    point_cache: Dict[Tuple[float, float], int] = {}  # (x, y) -> bend point id
    link_cache: Set[Tuple[int, int]] = set()

    size_for_kind = {
        "leaf": SIZE_LEAF_REAL,
        "leaf_marker": SIZE_LEAF_MARKER,
        "internal": SIZE_INTERNAL,
        "bend": SIZE_BEND,
    }
    color_for_kind = {
        "leaf": COLOR_LEAF,
        "leaf_marker": COLOR_LEAF,
        "internal": COLOR_INTERNAL,
        "bend": COLOR_BEND,
    }

    def add_point(kind: str, x: float, yv: float, *, label: str = "",
                  color: Optional[str] = None, size: Optional[float] = None,
                  cache_bend: bool = True) -> int:
        """Append a point and return its id; bends at an existing (x, y) are reused."""
        xq, yq = q(x), q(yv)
        is_cached_bend = cache_bend and kind == "bend"
        if is_cached_bend and (xq, yq) in point_cache:
            return point_cache[(xq, yq)]

        pid = len(pts_id)
        pts_id.append(pid)
        pts_x.append(xq)
        pts_y.append(yq)
        pts_size.append(size_for_kind.get(kind, 4.0) if size is None else size)
        pts_color.append(color_for_kind.get(kind, COLOR_BEND) if color is None else color)
        pts_label.append(label)
        pts_kind.append(kind)

        if is_cached_bend:
            point_cache[(xq, yq)] = pid
        return pid

    def add_link(s: int, t: int, *, color: Optional[str] = None) -> None:
        """Append link s -> t unless that exact link already exists."""
        key = (s, t)
        if key in link_cache:
            return
        link_cache.add(key)
        links_source.append(s)
        links_target.append(t)
        links_color.append(color or COLOR_LINK)

    # 3) create real tree nodes at initial X (`parent_stub` left of their own stem).
    #    Every non-root node is moved to its weighted X in step 4.
    t1 = time.time()
    for u, v in nodes.items():
        kind = "leaf" if not v.children else "internal"
        label = v.name if v.name else (u if kind == "leaf" else "")
        ex = q(stem_x(u))  # elbow x
        node_x_display = q(ex - float(parent_stub))
        pid = add_point(kind, node_x_display, y[u], label=label, cache_bend=False)
        node_id_actual[u] = pid
    logv(f"tree nodes creation: {time.time()-t1:.2f}s")

    # 4) orthogonal edges; move children to final weighted X
    EPS = 1e-6
    t2 = time.time()
    for u, v in nodes.items():
        ex = q(stem_x(u))
        y_parent = y[u]
        for c in v.children:
            y_child = y[c]
            true_len_px = max(0.0, float(nodes[c].blen)) * float(x_scale)
            child_pid = node_id_actual[c]
            pts_x[child_pid] = q(ex + float(weighted_stub) + true_len_px)

            elbow_top = add_point("bend", ex, y_parent, cache_bend=True)
            add_link(node_id_actual[u], elbow_top)
            if abs(y_parent - y_child) > EPS:
                elbow_bot = add_point("bend", ex, y_child, cache_bend=True)
                add_link(elbow_top, elbow_bot)
                add_link(elbow_bot, child_pid)
            else:
                add_link(elbow_top, child_pid)
    logv(f"orthogonal edges: {time.time()-t2:.2f}s")

    # 5) aligned right guideline X and leaf markers
    t3 = time.time()
    max_leaf_x = max((pts_x[node_id_actual[lf]] for lf in leaves), default=0.0)
    x_tipline = max_leaf_x + tip_pad

    for lf in leaves:
        leaf_pid = node_id_actual[lf]
        pid = add_point(
            "leaf_marker", x_tipline, pts_y[leaf_pid], label=pts_label[leaf_pid],
            color=COLOR_LEAF, size=SIZE_LEAF_MARKER, cache_bend=False,
        )
        add_link(pid, leaf_pid)
    logv(f"leaf markers: {time.time()-t3:.2f}s")

    # Build DataFrames from the column lists
    t4 = time.time()
    nodes_df = pd.DataFrame({
        "id": pts_id,
        "x": pts_x,
        "y": pts_y,
        "size": pts_size,
        "color": pts_color,
        "label": pts_label,
        "kind": pts_kind,
    })
    links_df = pd.DataFrame({
        "source": links_source,
        "target": links_target,
        "color": links_color,
    })
    logv(f"DataFrame creation: {time.time()-t4:.2f}s")
    logv(f"total: {time.time()-t0:.2f}s nodes_df={nodes_df.shape} links_df={links_df.shape} (markers @ x={x_tipline:.1f})")
    return nodes_df, links_df


## Rendering with Cosmograph

We map DataFrame columns to Cosmograph fields and disable the physics simulation:
- Coordinates are scaled (compact view), sizes are per-kind to emphasize leaves and markers
- Labels are shown only when hovered or dynamically when space allows
- Links are straight segments defined by source/target ids

This preserves interactivity for very large trees.


In [None]:
# Rendering — code

def render_cosmograph(
    nodes_df: pd.DataFrame,
    links_df: pd.DataFrame,
    *,
    link_px: float = LINK_WIDTH_PX,
    size_by_kind: Optional[Dict[str, float]] = None,
    default_size: float = 8.0,
    coord_scale: float = 0.3,
):
    """Create a Cosmograph widget from the display-graph DataFrames.

    Coordinates are multiplied by `coord_scale` (compact view). Point sizes come
    from a per-kind table; the ``size`` column of `nodes_df` is not used.

    Args:
        nodes_df: Needs columns id, x, y, color, label, kind.
        links_df: Needs columns source, target, color.
        link_px: Base link width; the widget receives twice this value.
        size_by_kind: Optional kind -> point size table. Defaults to
            {"leaf_marker": 24.0, "leaf": 18.0, "internal": 10.0, "bend": 3.0}.
        default_size: Size used for kinds missing from the table.
        coord_scale: Multiplier applied to x and y.

    Returns:
        The `cosmograph_widget.Cosmograph` widget.

    Raises:
        ValueError: if a required column is missing. Checked before importing
            cosmograph_widget and with a real exception, since `assert` is
            stripped under `python -O`.
    """
    missing_nodes = [c for c in ("id", "x", "y", "color", "label", "kind") if c not in nodes_df.columns]
    if missing_nodes:
        raise ValueError(f"nodes missing {missing_nodes}")
    missing_links = [c for c in ("source", "target", "color") if c not in links_df.columns]
    if missing_links:
        raise ValueError(f"links missing {missing_links}")

    from cosmograph_widget import Cosmograph

    t0 = time.time()

    kind_to_size: Dict[str, float] = (
        {"leaf_marker": 24.0, "leaf": 18.0, "internal": 10.0, "bend": 3.0}
        if size_by_kind is None
        else dict(size_by_kind)
    )

    # Build the point table in one go; the dict lookup is vectorized by pandas and
    # unknown kinds fall back to `default_size`.
    t1 = time.time()
    nodes_df_scaled = pd.DataFrame({
        "id": nodes_df["id"].astype(int),
        "x": nodes_df["x"].astype(float) * coord_scale,
        "y": nodes_df["y"].astype(float) * coord_scale,
        "pixel_size": nodes_df["kind"].map(kind_to_size).fillna(default_size).astype(float),
        "color": nodes_df["color"],
        "label": nodes_df["label"],
    })
    _log(f"[Render] DataFrame prep: {time.time()-t1:.2f}s")

    t2 = time.time()
    links_df_prep = pd.DataFrame({
        "source": links_df["source"].astype(int),
        "target": links_df["target"].astype(int),
        "color": links_df["color"],
    })
    _log(f"[Render] Links prep: {time.time()-t2:.2f}s")

    t3 = time.time()
    w = Cosmograph(
        points=nodes_df_scaled,
        links=links_df_prep,
        point_id_by="id",
        point_x_by="x",
        point_y_by="y",
        point_color_by="color",
        point_size_by="pixel_size",
        point_label_by="label",
        link_source_by="source",
        link_target_by="target",
        link_color_by="color",
        link_width_by=None,
        link_width=float(link_px) * 2,
        disable_simulation=True,  # positions are precomputed; no physics layout
        fit_view_on_init=True,
        fit_view_padding=0.06,
        enable_drag=False,
        show_hovered_point_label=True,
        show_dynamic_labels=True,
        show_legends=False,
        scale_points_on_zoom=True,
    )
    _log(f"[Render] Widget init: {time.time()-t3:.2f}s")
    _log(f"[Render] Total: {time.time()-t0:.2f}s")
    return w


## I/O and run

We read a local Newick file (or treat the argument as literal Newick text when no such path exists), take the first tree terminated by `;`, build the display graph, and render the widget. Adjust the parameters if lines look cramped or too sparse.


In [None]:
# I/O and execution — code

def read_newick_input(path_or_text: str) -> str:
    """Return Newick text from a file path, or the argument itself when no such path exists.

    Files are read as UTF-8 with undecodable bytes dropped. Either way, the result
    is stripped of surrounding whitespace, and a one-line summary is logged to stderr.
    """
    if not os.path.exists(path_or_text):
        _log(f"[Input] Treating argument as literal Newick (len={len(path_or_text):,})")
        return path_or_text.strip()

    with open(path_or_text, "r", encoding="utf-8", errors="ignore") as handle:
        content = handle.read()
    _log(f"[Input] Read file '{path_or_text}' len={len(content):,}")
    return content.strip()


def run_all(
    path_or_text: str,
    *,
    leaf_step: float = LEAF_Y_STEP_PX,
    parent_stub: float = PARENT_STUB_PX,
    tip_pad: float = TIP_PAD_PX,
    x_scale: float = X_SCALE_PX,
    min_level_gap: float = MIN_STEM_GAP_PX,
):
    """Run the whole pipeline: read input, parse the first tree, lay it out, render it.

    Only the first non-blank ';'-terminated segment of the input is used.

    Returns:
        (widget, nodes_df, links_df) so the DataFrames can be inspected afterwards.

    Raises:
        ValueError: if the input contains no non-blank tree text.
    """
    newick = read_newick_input(path_or_text)
    first_tree = next((chunk.strip() for chunk in newick.split(";") if chunk.strip()), "")
    if not first_tree:
        raise ValueError("No Newick tree found.")
    tree_s = f"{first_tree};"
    _log(f"[Run] Using first tree substring len={len(tree_s)}")

    layout_kwargs = dict(
        leaf_step=leaf_step,
        parent_stub=parent_stub,
        tip_pad=tip_pad,
        x_scale=x_scale,
        min_level_gap=min_level_gap,
    )
    nodes_df, links_df = build_display_graph(parse_newick(tree_s), **layout_kwargs)

    _log("[Run] Rendering widget...")
    widget = render_cosmograph(nodes_df, links_df)
    return widget, nodes_df, links_df


# Input and layout parameters. Set the COSMOTREE_NEWICK environment variable to a
# Newick file path (or literal Newick text) to override the default path below.
NEWICK: str = os.environ.get(
    "COSMOTREE_NEWICK", "/Users/gushchin_a/Downloads/UShER SARS-CoV-2 latest.nwk"
)
params: Dict[str, float] = {
    "x_scale": 140.0,
    "min_level_gap": 56.0,
    "leaf_step": 400.0,
    "parent_stub": 20.0,
    "tip_pad": 40.0,
}

w, nodes_df, links_df = run_all(NEWICK, **params)

try:
    from IPython.display import display
except ImportError:
    # Not running with IPython available: fall back to a plain text summary.
    print("Cosmograph widget created:", w)
else:
    # Errors raised while displaying are real problems and are not swallowed.
    display(w)
