In [None]:
# Load the watermark extension; its %watermark magic is used in a later cell
# to record run/environment metadata for reproducibility.
%load_ext watermark


In [None]:
import os

import alifedata_phyloinformatics_convert as apc
import dendropy as dp
from hstrat import _auxiliary_lib as hstrat_aux
import iplotx as ipx
from IPython.display import display
from matplotlib.colors import to_hex
from matplotlib.patches import ConnectionPatch
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import polars as pl
import seaborn as sns
from teeplot import teeplot as tp

import pylib  # noqa: F401


In [None]:
# Print run metadata and versions of imported packages (see `%watermark?`
# for the meaning of each flag). Kept on its own line: text after a line
# magic is passed to it as arguments.
%watermark -diwmuv -iv


In [None]:
# Output subdirectory for saved figures; overridable via NOTEBOOK_NAME so a
# batch runner can redirect outputs without editing the notebook.
teeplot_subdir = os.getenv(
    "NOTEBOOK_NAME",
    "2025-10-22-gosper_wse-gol",
)
teeplot_subdir


## Prep Data


In [None]:
# Parquet sources hosted on OSF; each file's position in this list becomes the
# value of its `surf` column, so the order here must not change.
surf_urls = [
    "https://osf.io/b7e8t/download",
    "https://osf.io/2k78y/download",
    "https://osf.io/6cmxa/download",
]

# Read each source with polars, tag its rows with `surf`, stack them, and
# convert to pandas for the downstream hstrat / seaborn tooling.
df = pl.concat(
    [
        pl.read_parquet(url).with_columns(surf=surf)
        for surf, url in enumerate(surf_urls)
    ],
).to_pandas()

# display() accepts several objects; no need for a tuple of display calls.
display(df.describe(), df.head(), df.tail())


In [None]:
# Inspect which columns the loaded data provides (later cells rely on
# id, surf, gol_state, row, col, among others).
df.columns


In [None]:
# Mark root ids separately within each surf (each surf is processed on its own
# by alifestd_mark_root_id), then restack with a fresh index. `taxon_label`
# mirrors `id`; presumably the DendroPy conversion uses it to label taxa
# (leaves are later matched back to rows by `id`).
df = pd.concat(
    [
        hstrat_aux.alifestd_mark_root_id(surf_df)
        for _surf, surf_df in df.groupby("surf")
    ],
    ignore_index=True,
).assign(taxon_label=lambda frame: frame["id"])


## Example Plot


In [None]:
# DendroPy trees for surf 0, with edge lengths set up for plotting.
trees_0 = apc.alife_dataframe_to_dendropy_trees(
    hstrat_aux.alifestd_try_add_ancestor_list_col(df.loc[df["surf"].eq(0)]),
    setup_edge_lengths=True,
)
trees_0


In [None]:
# DendroPy trees for surf 1, with edge lengths set up for plotting.
trees_1 = apc.alife_dataframe_to_dendropy_trees(
    hstrat_aux.alifestd_try_add_ancestor_list_col(df.loc[df["surf"].eq(1)]),
    setup_edge_lengths=True,
)
trees_1


In [None]:
# DendroPy trees for surf 2, with edge lengths set up for plotting.
trees_2 = apc.alife_dataframe_to_dendropy_trees(
    hstrat_aux.alifestd_try_add_ancestor_list_col(df.loc[df["surf"].eq(2)]),
    setup_edge_lengths=True,
)
trees_2


In [None]:
# Side length of the square raster: one past the largest row/col index among
# rows with gol_state >= 0, pooled across all surfs.
grid_dim = 1 + int(
    df.loc[df["gol_state"].ge(0), ["row", "col"]].to_numpy().max()
)
grid_dim


In [None]:
# For each surf, draw a 2x2 panel: target[0] along the top margin, target[1]
# along the left margin, the grid colored by root_id in the main panel, and
# dashed connectors from each margin-tree leaf to its taxon's grid cell.
# The top-left corner panel is left empty.
for i, target in enumerate([trees_0, trees_1, trees_2]):
    with tp.teed(
        plt.subplots,
        2,
        2,
        figsize=(10, 10),
        gridspec_kw={
            'width_ratios': [0.2, 0.8],
            'height_ratios': [0.2, 0.8],
            'wspace': 0.05,
            'hspace': 0.05,
        },
        teeplot_outattrs=dict(surf=i),
        teeplot_subdir=teeplot_subdir,
    ) as teed:
        fig, ((ax_corner, ax_top), (ax_left, ax_grid)) = teed

        # Raster of root_id at each (row, col) with gol_state >= 0; untouched
        # cells stay NaN and render blank. zip avoids iterrows' per-row Series
        # construction and dtype upcasting; writes remain sequential, so a
        # later row still overwrites an earlier one at the same cell.
        grid = np.full((grid_dim, grid_dim), np.nan)
        placed = df[(df["surf"] == i) & (df["gol_state"] >= 0)]
        for r, c, root in zip(placed["row"], placed["col"], placed["root_id"]):
            grid[int(r), int(c)] = root

        ax_corner.axis('off')

        # One color per tree: tree k's edges use cmap[k], and with
        # vmin=0 / vmax=len(target) - 1 the heatmap maps value k to cmap[k].
        cmap = sns.color_palette(["Dark2", "Set1", "tab10_r"][i], len(target))

        tree_top = ipx.plotting.tree(
            target[0],
            ax=ax_top,
            layout="vertical",
            edge_color=to_hex(cmap[0]),
            edge_linewidth=1.5,
            margins=0.0,
        )
        # Hand-tuned limits, presumably to line the top tree up with the grid.
        ax_top.margins(x=-0.04)
        ax_top.set_xlim(ax_top.get_xlim()[0] - 10, None)

        tree_left = ipx.plotting.tree(
            target[1],
            ax=ax_left,
            edge_color=to_hex(cmap[1]),
            edge_linewidth=1.5,
            margins=0.0,
        )
        # Flip (presumably to match the heatmap's top-down row order) and
        # apply a hand-tuned margin.
        ax_left.invert_yaxis()
        ax_left.margins(y=-0.05)

        sns.heatmap(
            grid,
            ax=ax_grid,
            cmap=cmap,
            vmin=0,
            vmax=len(target) - 1,
            cbar=False,
        )
        ax_grid.set_axis_off()

        dfi = df[df["surf"] == i].copy()
        dfi = hstrat_aux.alifestd_mark_leaves(dfi, mutate=True)

        # root_id -> (axis, tree artist) for the phylogenies drawn in the
        # margins; leaves of any other root get no connector. A dict also
        # avoids IndexError / negative-index wraparound for unexpected ids.
        # NOTE(review): assumes target[k] is the tree with root_id k — confirm.
        margin_trees = {0: (ax_top, tree_top), 1: (ax_left, tree_left)}

        # Taxon label -> layout coordinates, built once per margin tree rather
        # than rescanning the whole layout for every leaf. setdefault keeps
        # the first match, as a linear next() search would.
        leaf_coords = {}
        for root_id, (_, tree) in margin_trees.items():
            coords = {}
            for node, xy in tree.get_layout().T.items():
                if node.taxon is not None:
                    coords.setdefault(node.taxon.label, xy)
            leaf_coords[root_id] = coords

        for _, row in dfi[dfi["is_leaf"]].iterrows():
            root_id = int(row["root_id"])
            if root_id not in margin_trees:
                continue
            axis, _ = margin_trees[root_id]
            tree_x, tree_y = leaf_coords[root_id][row["id"]]
            # sns.heatmap draws cell (row, col) over [col, col + 1] x
            # [row, row + 1] in data coordinates; offset by 0.5 so the
            # connector ends at the cell center instead of its corner.
            grid_x, grid_y = row["col"] + 0.5, row["row"] + 0.5

            # Dashed connector from the leaf tip (tree axis) to its grid cell.
            con = ConnectionPatch(
                xyA=(tree_x, tree_y),
                xyB=(grid_x, grid_y),
                coordsA=axis.transData,
                coordsB=ax_grid.transData,
                color="gray",
                linestyle="--",
                alpha=0.5,
                linewidth=0.5,
                clip_on=False,
            )
            fig.add_artist(con)
