In [None]:
# -----------------------------------------------------------------------------
#  CosmoTree — minimal, *clean* viewer for weighted phylogenies in Cosmograph
#  (no duplicate imports / definitions; tweak constants in CONFIG only)
# -----------------------------------------------------------------------------

from __future__ import annotations

# ===== CONFIG (edit below) =====================================================
# Layout geometry. These are the default argument values of build_display_graph.
X_SCALE_PX = 140.0          # pixels per unit of branch length
MIN_STEM_GAP_PX = 56.0      # minimum horizontal distance between consecutive stem X positions
PARENT_STUB_PX = 20.0       # offset added to a node's X to obtain its stem (elbow) X
LEAF_Y_STEP_PX = 400.0      # vertical distance between consecutive leaves
TIP_PAD_PX = 40.0           # distance from the right-most leaf to the aligned marker column

# Per-kind point sizes written to the `size` column of the node table.
# NOTE(review): render_cosmograph passes a fixed `point_size` to the widget;
# confirm whether the per-kind sizes below are still meant to be displayed.
SIZE_LEAF_MARKER = 20.0     # aligned leaf marker in the right-hand column
SIZE_INTERNAL    = 6.0      # internal tree nodes
SIZE_BEND        = 3.0      # elbow / technical nodes
SIZE_LEAF_REAL   = 8.0      # leaf point at its branch-length X position

# Global multiplier applied to the `size` column in render_cosmograph.
NODE_SIZE_SCALE = 2.0       # multiplier for all node sizes

# Colors (hex strings) and link width
COLOR_LEAF      = "#f5d76e"
COLOR_INTERNAL  = "#8ab4f8"
COLOR_BEND      = "#9aa0a6"
COLOR_LINK      = "#97A1A9"
LINK_WIDTH_PX   = 0.7

# ==============================================================================
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple
import re, sys, os, pandas as pd

# --- helpers ------------------------------------------------------------------

def _log(*a):
    print(*a, file=sys.stderr)


# --- Robust Newick parser ----------------------------------------------------
import re

@dataclass
class TNode:
    """One node of a parsed Newick tree; relatives are referenced by node id."""
    id: str                       # key of this node in the node dictionary
    name: str = ""                # label taken from the Newick text ("" when unlabeled)
    parent: Optional[str] = None  # id of the parent node; None marks a root
    blen: float = 0.0             # length of the branch leading to this node
    children: List[str] = field(default_factory=list)  # child ids, in order

_TOKEN_RE = re.compile(r"\s*([(),;])\s*|\s*([^(),:;]+)\s*|(\s*:\s*[-+]?\d*\.?\d+(?:[eE][-+]?\d+)?)")

def _tokenize(newick: str):
    i = 0
    s = newick.strip()
    while i < len(s):
        m = _TOKEN_RE.match(s, i)
        if not m:
            if s[i].isspace():
                i += 1
                continue
            raise ValueError(f"[Newick] Unexpected token near: {s[i:i+20]!r}")
        tok = m.group(1) or m.group(2) or m.group(3)
        if tok is not None:
            yield tok.strip()
        i = m.end()

def parse_newick(newick: str) -> Dict[str, TNode]:
    """Parse one Newick tree into a dictionary of ``TNode`` objects keyed by id.

    Node ids are generated as ``n1``, ``n2``, ... in the order nodes are
    created. A label or ``:length`` that directly follows ``)`` is attached to
    the internal node that was just closed; a label anywhere else creates a
    leaf, and a bare ``:length`` with no preceding node creates an unnamed
    leaf. Parsing stops at the first ``;``.

    If the text yields several parentless nodes, a synthetic node ``root0``
    is added as their common parent.

    Args:
        newick: Newick text, tokenized with ``_tokenize``.

    Returns:
        Mapping from node id to ``TNode``.

    Raises:
        ValueError: on an unbalanced ``)``, when no root can be found, or
            (from ``_tokenize``) on an unrecognized token.
    """
    nodes: Dict[str, TNode] = {}
    open_parents: List[Optional[str]] = []   # parent to restore at each ")"
    counter = 0

    def new_id() -> str:
        nonlocal counter
        counter += 1
        return f"n{counter}"

    def add_node(parent: Optional[str], name: str = "") -> str:
        # Create a node and register it with its parent (if any).
        u = new_id()
        nodes[u] = TNode(id=u, name=name, parent=parent)
        if parent is not None:
            nodes[parent].children.append(u)
        return u

    current_parent: Optional[str] = None
    last: Optional[str] = None          # node that a following ":length" belongs to
    just_closed: Optional[str] = None   # internal node whose ")" was just read

    for tok in _tokenize(newick):
        if not tok:
            # Whitespace-only token; carries no information.
            continue
        if tok == "(":
            u = add_node(current_parent)
            open_parents.append(current_parent)
            current_parent = u
            last = just_closed = None
        elif tok == ",":
            last = just_closed = None
        elif tok == ")":
            if current_parent is None:
                raise ValueError("[Parse] Unbalanced ')' in Newick string.")
            # Label/length tokens that follow describe the node being closed,
            # not its parent.
            last = just_closed = current_parent
            current_parent = open_parents.pop()
        elif tok == ";":
            break
        elif tok.startswith(":"):
            length = float(tok[1:].strip())
            if last is None:
                # Length without a label: an unnamed leaf.
                last = add_node(current_parent)
            nodes[last].blen = length
            just_closed = None
        else:
            if just_closed is not None:
                nodes[just_closed].name = tok
                just_closed = None
            else:
                last = add_node(current_parent, name=tok)

    roots = [k for k, v in nodes.items() if v.parent is None]
    if not roots:
        raise ValueError("[Parse] No root detected.")
    if len(roots) == 1:
        root_id = roots[0]
    else:
        root_id = "root0"
        nodes[root_id] = TNode(id=root_id, name="root", parent=None, blen=0.0, children=roots)
        for r in roots:
            nodes[r].parent = root_id

    _log(f"[Parse] nodes={len(nodes):,} leaves={sum(1 for v in nodes.values() if not v.children):,} root={root_id}")
    return nodes

# --- Layout functions ---------------------------------------------------------

def _collect_leaves(nodes: Dict[str, TNode], u: str) -> List[str]:
    if not nodes[u].children:
        return [u]
    acc: List[str] = []
    for c in nodes[u].children:
        acc.extend(_collect_leaves(nodes, c))
    return acc

def _find_root(nodes: Dict[str, TNode]) -> str:
    for k, v in nodes.items():
        if v.parent is None:
            return k
    raise ValueError("No root")

def _sort_children_for_no_crossing(nodes: Dict[str, TNode]):
    def min_leaf_name(u: str) -> str:
        names = [nodes[x].name or x for x in _collect_leaves(nodes, u)]
        return min(names)
    for u, v in nodes.items():
        if v.children:
            v.children.sort(key=min_leaf_name)

def compute_cumdist(nodes: Dict[str, TNode], root: Optional[str] = None) -> Dict[str, float]:
    """Return each node's cumulative branch length measured from the root.

    The root (looked up with ``_find_root`` when not given) is at 0.0; every
    other reachable node is at its parent's distance plus its own ``blen``,
    with negative branch lengths counted as zero.
    """
    start = _find_root(nodes) if root is None else root
    totals = {start: 0.0}
    todo = [start]
    while todo:
        parent = todo.pop()
        kids = nodes[parent].children
        for child in kids:
            step = max(0.0, float(nodes[child].blen))
            totals[child] = totals[parent] + step
        todo.extend(kids)
    return totals

def assign_y_equal_leaf_spacing(nodes: Dict[str, TNode], leaf_step: float) -> Dict[str, float]:
    """Return a Y coordinate for every node reachable from the root.

    Children lists are first sorted in place (``_sort_children_for_no_crossing``).
    Leaves then get ``index * leaf_step`` in left-to-right order, and every
    internal node gets the mean Y of its direct children.
    """
    root = _find_root(nodes)
    _sort_children_for_no_crossing(nodes)

    # Evenly spaced leaves.
    y: Dict[str, float] = {}
    for rank, leaf in enumerate(_collect_leaves(nodes, root)):
        y[leaf] = rank * leaf_step

    # Pre-order walk; iterating it backwards visits children before parents.
    preorder: List[str] = []
    seen = set()
    pending = [root]
    while pending:
        cur = pending.pop()
        preorder.append(cur)
        pending.extend(c for c in nodes[cur].children if c not in seen)
        seen.add(cur)

    for cur in reversed(preorder):
        kids = nodes[cur].children
        if kids:
            y[cur] = sum(y[k] for k in kids) / len(kids)
        else:
            # Fallback for a childless node that received no leaf slot.
            y.setdefault(cur, 0.0)
    return y

def build_display_graph(
    nodes: Dict[str, TNode],
    leaf_step: float = LEAF_Y_STEP_PX,
    parent_stub: float = PARENT_STUB_PX,
    tip_pad: float = TIP_PAD_PX,
    x_scale: float = X_SCALE_PX,
    min_level_gap: float = MIN_STEM_GAP_PX,
):
    """Lay out a parsed tree as an orthogonal diagram and return point/link tables.

    Args:
        nodes: Tree as returned by ``parse_newick``. The children lists are
            re-sorted in place (via ``assign_y_equal_leaf_spacing``).
        leaf_step: Vertical distance between consecutive leaves.
        parent_stub: Offset from a node's X to the X of its vertical stem.
        tip_pad: Distance from the right-most leaf to the marker column.
        x_scale: Multiplier converting branch length to X units.
        min_level_gap: Minimum distance between consecutive distinct stem X values.

    Returns:
        ``(nodes_df, links_df)``: ``nodes_df`` has columns ``id, x, y, size,
        color, label, kind``; ``links_df`` has columns ``source, target,
        color`` whose endpoints are ``nodes_df`` ids.
    """
    def logv(*a): _log("[Layout]", *a)

    root = _find_root(nodes)
    dist = compute_cumdist(nodes, root)                 # cumulative branch lengths (tree units)
    y    = assign_y_equal_leaf_spacing(nodes, leaf_step)

    # 1) scale X by branch length
    dist_px = {u: float(dist[u]) * float(x_scale) for u in nodes}

    # 2) compute stems and spread horizontally so verticals don't overlap:
    #    walk the distinct stem positions left to right and push each one to at
    #    least `min_level_gap` right of the previous (already spread) stem.
    raw_stems = sorted({dist_px[u] + parent_stub for u in nodes})
    spread_stems, last = [], None
    for x in raw_stems:
        spread = x if last is None else max(x, last + float(min_level_gap))
        spread_stems.append(spread); last = spread

    # q() rounds to 6 decimals so that floats can be used as lookup/cache keys.
    def q(v: float) -> float: return float(f"{v:.6f}")
    stem_map = {q(o): s for o, s in zip(raw_stems, spread_stems)}
    def stem_x(u: str) -> float: return stem_map[q(dist_px[u] + parent_stub)]

    leaves = [k for k, v in nodes.items() if not v.children]

    pts, links = [], []
    node_id_actual: Dict[str, int] = {}     # tree node id -> integer point id
    next_id = 0
    point_cache: Dict[Tuple[str, float, float], int] = {}   # shared bend points by (kind, x, y)
    link_cache: set[Tuple[int, int]] = set()                 # (source, target) pairs already emitted

    def add_point(kind, x, yv, label="", color=None, size=None, cache_bend=True):
        # Append a point row and return its integer id. With cache_bend=True,
        # a "bend" at already-used (quantized) coordinates reuses the
        # existing point instead of creating a duplicate.
        nonlocal next_id
        xq, yq = q(x), q(yv)
        if cache_bend and kind == "bend":
            key = (kind, xq, yq)
            if key in point_cache:
                return point_cache[key]
        pid = next_id; next_id += 1

        # Default size per kind (unknown kinds get 4.0).
        if size is None:
            if kind == "leaf":
                size = SIZE_LEAF_REAL
            elif kind == "leaf_marker":
                size = SIZE_LEAF_MARKER
            elif kind == "internal":
                size = SIZE_INTERNAL
            elif kind == "bend":
                size = SIZE_BEND
            else:
                size = 4.0

        # Default color per kind (unknown kinds use the bend color).
        if color is None:
            if kind == "leaf" or kind == "leaf_marker": color = COLOR_LEAF
            elif kind == "internal":  color = COLOR_INTERNAL
            elif kind == "bend":      color = COLOR_BEND
            else:                     color = COLOR_BEND

        pts.append({
            "id": int(pid),
            "x": xq, "y": yq,
            "size": size,
            "color": color,
            "label": label,
            "kind": kind,
        })
        if cache_bend and kind == "bend":
            point_cache[(kind, xq, yq)] = pid
        return pid

    def add_link(s, t, color=None):
        # Append a directed link once; repeated (source, target) pairs are skipped.
        key = (int(s), int(t))
        if key in link_cache: return
        link_cache.add(key)
        links.append({"source": int(s), "target": int(t), "color": color or COLOR_LINK})

    # 3) create real tree nodes at initial X; we'll move children afterward.
    #    Unnamed leaves are labeled with their node id; unnamed internals get "".
    for u, v in nodes.items():
        kind = "leaf" if not v.children else "internal"
        label = v.name if v.name else (u if kind == "leaf" else "")
        pid = add_point(kind, dist_px[u], y[u], label=label, cache_bend=False)
        node_id_actual[u] = pid
    # Point ids equal list positions here; the map is kept for lookups below.
    id_to_idx = {row["id"]: i for i, row in enumerate(pts)}

    # 4) orthogonal edges; move children to final weighted X = spread stem + true length
    EPS = 1e-6
    for u, v in nodes.items():
        sx = stem_x(u); y_parent = y[u]
        k = len(v.children)
        for c in v.children:
            y_child = y[c]
            true_len_px = max(0.0, float(nodes[c].blen)) * float(x_scale)
            child_pid = node_id_actual[c]
            pts[id_to_idx[child_pid]]["x"] = q(sx + true_len_px)

            if k == 1 and abs(y_parent - y_child) < EPS:
                # Only child at the parent's height: a single straight link.
                add_link(node_id_actual[u], child_pid)
            else:
                # parent -> elbow on the stem at parent height -> (elbow at
                # child height, when it differs) -> child
                elbow_top = add_point("bend", sx, y_parent, cache_bend=True)
                if abs(y_parent - y_child) > EPS:
                    elbow_bot = add_point("bend", sx, y_child, cache_bend=True)
                    add_link(node_id_actual[u], elbow_top)
                    add_link(elbow_top, elbow_bot)
                    add_link(elbow_bot, child_pid)
                else:
                    add_link(node_id_actual[u], elbow_top)
                    add_link(elbow_top, child_pid)

    # 5) compute shared right guideline X and add aligned leaf markers; each
    #    marker is linked to its real leaf point.
    max_leaf_x = max(pts[id_to_idx[node_id_actual[lf]]]["x"] for lf in leaves) if leaves else 0.0
    x_tipline  = max_leaf_x + tip_pad
    for lf in leaves:
        leaf_row = pts[id_to_idx[node_id_actual[lf]]]
        pid = add_point("leaf_marker", x_tipline, leaf_row["y"],
                  label=leaf_row["label"], color=COLOR_LEAF, size=SIZE_LEAF_MARKER, cache_bend=False)
        add_link(pid, leaf_row["id"])

    nodes_df = pd.DataFrame(pts)
    links_df = pd.DataFrame(links)
    logv(f"nodes_df={nodes_df.shape} links_df={links_df.shape} (markers @ x={x_tipline:.1f})")
    return nodes_df, links_df

# --- Rendering ----------------------------------------------------------------

def render_cosmograph(nodes_df, links_df, *, link_px: float = LINK_WIDTH_PX):
    """Build a static (simulation disabled) Cosmograph widget from the layout tables.

    Args:
        nodes_df: Point table with at least ``id, x, y, size, color, label``.
        links_df: Link table with at least ``source, target, color``.
        link_px: Link width passed to the widget.

    Returns:
        The ``Cosmograph`` widget instance.

    Raises:
        AssertionError: if a required column is missing.

    The inputs are copied before dtype coercion. The ``size`` column is scaled
    by ``NODE_SIZE_SCALE`` but is not mapped to the widget, which is given a
    fixed ``point_size`` instead.
    """
    from cosmograph_widget import Cosmograph

    # Minimal validation of required columns.
    required = (
        (nodes_df, ("id", "x", "y", "size", "color", "label"), "nodes"),
        (links_df, ("source", "target", "color"), "links"),
    )
    for frame, columns, what in required:
        for col in columns:
            assert col in frame.columns, f"{what} missing {col}"

    # Work on copies so the caller's frames keep their dtypes.
    nodes_df = nodes_df.copy()
    links_df = links_df.copy()
    for col, dtype in (("id", int), ("x", float), ("y", float)):
        nodes_df[col] = nodes_df[col].astype(dtype)
    # Apply global size scaling
    nodes_df["size"] = nodes_df["size"].astype(float) * NODE_SIZE_SCALE
    for col in ("source", "target"):
        links_df[col] = links_df[col].astype(int)

    print("Creating Cosmograph with fixed node size...")

    widget_options = dict(
        points=nodes_df,
        links=links_df,
        point_id_by="id",
        point_x_by="x",
        point_y_by="y",
        point_color_by="color",
        point_size=40,  # one fixed size for every point
        point_label_by="label",
        link_source_by="source",
        link_target_by="target",
        link_color_by="color",
        link_width_by=None,
        link_width=float(link_px),
        disable_simulation=True,
        fit_view_on_init=True,
        fit_view_padding=0.06,
        enable_drag=False,
        show_hovered_point_label=True,
        show_dynamic_labels=True,
        show_legends=False,
        scale_points_on_zoom=True,
    )
    w = Cosmograph(**widget_options)
    print("✓ Cosmograph created with fixed size=40")
    return w

def read_newick_input(path_or_text: str) -> str:
    """Return Newick text from a file path or from the argument itself.

    If *path_or_text* names an existing path, the file is read as UTF-8
    (undecodable bytes are dropped); otherwise the argument is treated as
    literal Newick text. The result is stripped of surrounding whitespace.
    """
    if not os.path.exists(path_or_text):
        _log(f"[Input] Treating argument as literal Newick (len={len(path_or_text):,})")
        return path_or_text.strip()

    with open(path_or_text, "r", encoding="utf-8", errors="ignore") as handle:
        content = handle.read()
    _log(f"[Input] Read file '{path_or_text}' len={len(content):,}")
    return content.strip()

def run_all(path_or_text: str,
            leaf_step: float = LEAF_Y_STEP_PX,
            parent_stub: float = PARENT_STUB_PX,
            tip_pad: float = TIP_PAD_PX,
            **layout_kwargs):
    """Load Newick input, lay out its first tree, and render it.

    Args:
        path_or_text: File path or literal Newick text (see ``read_newick_input``).
        leaf_step, parent_stub, tip_pad: Forwarded to ``build_display_graph``.
        **layout_kwargs: Further keyword arguments for ``build_display_graph``
            (e.g. ``x_scale``, ``min_level_gap``).

    Returns:
        ``(widget, nodes_df, links_df)``.

    Raises:
        ValueError: if the input contains no non-blank tree text.
    """
    newick = read_newick_input(path_or_text)

    # Only the first non-blank ';'-terminated segment is used, so scan for it
    # instead of splitting the whole input (multi-tree files can be very large).
    first_tree = ""
    start = 0
    while True:
        end = newick.find(";", start)
        segment = (newick[start:] if end == -1 else newick[start:end]).strip()
        if segment or end == -1:
            first_tree = segment
            break
        start = end + 1
    if not first_tree:
        raise ValueError("No Newick tree found.")

    tree_s = first_tree + ";"
    _log(f"[Run] Using first tree substring len={len(tree_s)}")
    nodes = parse_newick(tree_s)
    nodes_df, links_df = build_display_graph(
        nodes,
        leaf_step=leaf_step,
        parent_stub=parent_stub,
        tip_pad=tip_pad,
        **layout_kwargs
    )
    _log("[Run] Rendering widget...")
    w = render_cosmograph(nodes_df, links_df)
    return w, nodes_df, links_df

# --- Main execution -----------------------------------------------------------

# 1) Put the path to your Newick file here:
NEWICK = "/Users/gushchin_a/Downloads/Chond 10Cal 10k TreeSet.tre"

# 2) Layout parameters come from the CONFIG constants at the top of this cell,
#    so there is a single place to tune them. If lines are cramped
#    horizontally, raise X_SCALE_PX. If vertical trunks sit too close, raise
#    MIN_STEM_GAP_PX. For many leaves, you might lower LEAF_Y_STEP_PX a bit.
params = dict(
    x_scale=X_SCALE_PX,             # pixels per branch-length unit
    min_level_gap=MIN_STEM_GAP_PX,  # minimal horizontal gap between stems
    leaf_step=LEAF_Y_STEP_PX,       # vertical spacing between leaves
    parent_stub=PARENT_STUB_PX,     # length of elbow stub
    tip_pad=TIP_PAD_PX,             # extra space to the right of the farthest leaf
)

w, nodes_df, links_df = run_all(NEWICK, **params)

# Use display if available (Jupyter), otherwise print
try:
    from IPython.display import display
    display(w)
except ImportError:
    print("Cosmograph widget created:", w)


[Input] Read file '/Users/gushchin_a/Downloads/Chond 10Cal 10k TreeSet.tre' len=549,947,364


Creating Cosmograph with different node sizes by type...
✓ Cosmograph created with variable node sizes


[Run] Using first tree substring len=54968
[Parse] nodes=2,383 leaves=1,192 root=n1
[Layout] nodes_df=(6580, 7) links_df=(7147, 3) (markers @ x=98500.9)
[Run] Rendering widget...


Cosmograph(background_color=None, disable_simulation=True, enable_drag=False, fit_view_on_init=True, fit_view_…