<a href="https://colab.research.google.com/github/mmore500/hstrat-synthesis/blob/main/pylib/track_outlier.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [47]:
!pip install hstrat downstream iplotx teeplot



In [48]:
!wget https://raw.githubusercontent.com/mmore500/hstrat-synthesis/refs/heads/main/pylib/track_ca.py

--2026-01-30 06:53:00--  https://raw.githubusercontent.com/mmore500/hstrat-synthesis/refs/heads/main/pylib/track_ca.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 16892 (16K) [text/plain]
Saving to: ‘track_ca.py.2’


2026-01-30 06:53:00 (141 MB/s) - ‘track_ca.py.2’ saved [16892/16892]



In [49]:
# Mount Google Drive (Colab-only API). The last cell of this notebook pickles
# its results under /content/drive/MyDrive, so the mount must succeed first.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
import itertools as it
import base64

In [4]:
import cupy as xp  # array backend, referred to as `xp` everywhere below; numpy is presumably a drop-in substitute (not verified here)
xp.__version__

'13.6.0'

In [5]:
# Base64 text of the automaton's lookup table. The stored string omits its
# base64 padding, so "==" is appended before decoding.
rle = "ERETQB4eHWkQ7xD4eYZosBQZFixOBHmtFeehExrKVhURLRAqGxeIlSO1JYZP6DRi69rop7TQCkvWTIag7kAS8g"
padded_bytes = base64.b64decode(rle + "==")
# Spell every decoded byte out as its eight binary digits, most significant first.
padded_bits = "".join(format(byte, "08b") for byte in padded_bytes)
# One uint16 entry (0 or 1) per bit; `update` indexes this table with the
# packed neighbourhood codes.
rule_table = xp.fromiter((int(bit) for bit in padded_bits), dtype=xp.uint16)
rule_table

array([0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0,
       1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1,
       1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0,
       0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0,
       1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1,
       1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1,
       0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0,
       1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0,
       0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1,
       0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1,
       0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1,
       0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1,
       0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0,
       1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0,

In [6]:
def pack_neighborhoods(cells : xp.ndarray) -> xp.ndarray:
    """Encode each cell's 3x3 neighbourhood as a 9-bit integer.

    The nine offsets are visited in row-major order, (-1, -1) first and
    (1, 1) last, and the offset visited at position ``k`` is stored in bit
    ``k``. Bit 0 is therefore the neighbour at (row - 1, col - 1), bit 4 is
    the cell itself, and bit 8 is the neighbour at (row + 1, col + 1).
    Neighbours are fetched with ``xp.roll``, so the grid wraps around at
    its edges.

    Args:
        cells: 2-D array of 0/1 cell states (callers in this notebook pass
            uint16).

    Returns:
        uint16 array with the same shape as ``cells`` holding the packed
        neighbourhood codes.
    """
    packed = xp.zeros_like(cells, dtype=xp.uint16)
    offsets = it.product((-1, 0, 1), repeat=2)
    for bit, (row_off, col_off) in enumerate(offsets):
        weighted = cells << bit
        # Rolling by the negated offset lines cell (r + row_off, c + col_off)
        # up with position (r, c).
        packed |= xp.roll(weighted, shift=(-row_off, -col_off), axis=(0, 1))
    return packed


In [7]:
def test_pack_neighborhoods():
    """Check the bit layout and the toroidal wrap of `pack_neighborhoods`."""
    # --- Case 1: a single live cell in the middle of a 3x3 torus ---
    # Each of the nine cells sees the live centre through a different
    # neighbour slot, so the packed grid lists every bit position once
    # (bit 4, value 16, is the cell looking at itself).
    lone_center = xp.zeros((3, 3), dtype=xp.uint16)
    lone_center[1, 1] = 1
    packed = pack_neighborhoods(lone_center)

    expected = xp.array(
        [[256, 128, 64], [32, 16, 8], [4, 2, 1]],
        dtype=xp.uint16,
    )
    assert xp.array_equal(packed, expected)

    # --- Case 2: neighbour perspective ---
    # (1, 0) sits directly west of the live centre, so it finds the live
    # cell in its East slot: bit 5, i.e. 1 << 5 == 32.
    assert packed[1, 0] == 32, f"Expected 32 (East), got {packed[1, 0]}"

    # --- Case 3: toroidal wrap with the live cell in a corner ---
    lone_corner = xp.zeros((3, 3), dtype=xp.uint16)
    lone_corner[0, 0] = 1
    packed = pack_neighborhoods(lone_corner)

    # From the bottom-right cell (2, 2), stepping south-east wraps round to
    # (0, 0). South-east is bit 8, i.e. 256.
    assert packed[2, 2] == 256, f"Expected 256 (SE wrap), got {packed[2, 2]}"

    # From the top-right cell (0, 2), stepping east wraps round to (0, 0).
    # East is bit 5, i.e. 32.
    assert packed[0, 2] == 32, f"Expected 32 (East wrap), got {packed[0, 2]}"

    # --- Case 4: every cell alive ---
    # All nine slots are set everywhere: 0b111111111 == 511.
    all_alive = xp.ones((3, 3), dtype=xp.uint16)
    packed = pack_neighborhoods(all_alive)
    assert (packed == 511).all()

    print("All assertions passed!")

test_pack_neighborhoods()

All assertions passed!


In [8]:
def update(cells : xp.ndarray, rule: xp.ndarray) -> xp.ndarray:
    """Advance the automaton by one generation.

    Side effect: the outermost rows and columns of ``cells`` are zeroed
    IN PLACE before the rule is applied, so the caller's array is modified.
    This keeps live cells from interacting across the wrap-around that
    ``pack_neighborhoods`` would otherwise provide.

    Args:
        cells: 2-D array of current cell states; its border is overwritten
            with zeros.
        rule: lookup table indexed by the packed neighbourhood code.

    Returns:
        A new array, shaped like ``cells``, holding ``rule`` looked up at
        every cell's neighbourhood code.
    """
    # prevent wraparound: clear top row, bottom row, left column, right column
    cells[:1] = 0
    cells[-1:] = 0
    cells[:, :1] = 0
    cells[:, -1:] = 0
    # apply rule
    neighborhood_codes = pack_neighborhoods(cells)
    return rule[neighborhood_codes]


In [9]:
def parse_rle(rle_string: str) -> xp.ndarray:
    """Parse the body of a run-length-encoded Life pattern into a 0/1 grid.

    Recognised tokens, each optionally preceded by a decimal run count
    (default 1):

    * ``b`` - dead cells
    * ``o`` - live cells
    * ``$`` - end of row; a count ``n`` ends the current row and then adds
      ``n - 1`` empty rows (the count used to be dropped, which silently
      removed blank rows from the pattern)
    * ``!`` - end of pattern; everything after it is ignored

    Any other character (whitespace, newlines, ...) is skipped and does not
    reset a pending run count. Header lines (``x = ..., y = ...``) and ``#``
    comment lines are NOT handled; pass only the pattern body.

    Args:
        rle_string: the RLE pattern body.

    Returns:
        2-D uint16 array with one row per pattern row, short rows
        right-padded with zeros to the widest row. A non-empty row still
        open when the input ends is kept; an empty trailing row is dropped.
    """
    rows = []
    current_row = []
    count = ""

    for char in rle_string:
        if char.isdigit():
            count += char
            continue

        run_length = int(count) if count else 1
        if char == "b":
            current_row.extend([0] * run_length)
        elif char == "o":
            current_row.extend([1] * run_length)
        elif char == "$":
            rows.append(current_row)
            # "n$" means n row ends: the extra n - 1 rows are entirely dead.
            rows.extend([] for _ in range(run_length - 1))
            current_row = []
        elif char == "!":
            break
        else:
            # Unrecognised character: skip it and keep any pending run count.
            continue
        count = ""

    # Keep a final row that was not terminated by '$'.
    if current_row:
        rows.append(current_row)

    # Pad rows on the right so the result is rectangular.
    width = max(map(len, rows), default=0)
    padded_rows = [row + [0] * (width - len(row)) for row in rows]

    return xp.array(padded_rows, dtype=xp.uint16)


In [10]:
# Seed pattern in RLE notation; it is parsed into a 0/1 array and printed so
# the result can be checked by eye. `parsed_pattern` is stamped into the
# middle of the simulation grid by a later cell.
rle_pattern_string = "$bo$3o$2bo!"
parsed_pattern = parse_rle(rle_pattern_string)
print(parsed_pattern)

[[0 0 0]
 [0 1 0]
 [1 1 1]
 [0 0 1]]


In [50]:
# Side length of the square simulation grid.
grid_size = 2300
pattern_height, pattern_width = parsed_pattern.shape

# Top-left corner that centres the pattern; floor division pushes an odd
# remainder towards the top/left.
offset_y = (grid_size - pattern_height) // 2
offset_x = (grid_size - pattern_width) // 2

# Start from an all-dead grid and stamp the seed pattern into its centre.
grid = xp.zeros((grid_size, grid_size), dtype=xp.uint16)
row_span = slice(offset_y, offset_y + pattern_height)
col_span = slice(offset_x, offset_x + pattern_width)
grid[row_span, col_span] = parsed_pattern


In [None]:
# Run the automaton for `num_steps` generations while feeding every state to a
# MultiTracker. MultiTracker comes from the downloaded track_ca.py; what its
# constructor/initialize arguments mean is not visible in this notebook.
from track_ca import MultiTracker
from downstream import dstream
import tqdm

# The tracker is initialised from the starting grid, cast to bool.
tracker = MultiTracker(
    dstream_algo=dstream.steady_algo, dstream_S=64, dstream_bitwidth=1, N=15, backend=xp
).initialize(grid.astype(xp.bool_), R=1)

num_steps = 15000
for _ in tqdm.trange(num_steps):
    # `grid` is rebound on every step, so re-running this cell continues from
    # the evolved state instead of the seed; re-run the grid setup cell first
    # for a fresh start. `update` also zeroes the border of the array passed in.
    grid = update(grid, rule_table)
    # NOTE(review): add_random_noise=True with no seed set anywhere in view;
    # results are presumably not reproducible run-to-run - confirm in track_ca.
    tracker.step(grid.astype(xp.bool_), add_random_noise=True)

print(f"\nTotal states stored: {tracker.N}")

  2%|▏         | 281/15000 [00:43<36:24,  6.74it/s]

In [None]:
# Ask the tracker for its reconstructed phylogenies. The result is iterated
# below as a collection of tables, one per phylogeny, with at least the
# columns "extant" and "dstream_data_id".
phylos = tracker.reconstruct_phylogenies()

In [None]:
# Invariant relied on further down: within every phylogeny the extant rows
# already appear in increasing dstream_data_id order, so the per-phylogeny
# MRCA vectors line up with each other without any re-ordering.
for phylo in phylos:
    extant_rows = phylo[phylo["extant"] == True]
    assert extant_rows["dstream_data_id"].is_monotonic_increasing

In [None]:
# For every extant id, measure how much the reconstructed phylogenies disagree
# about the hstrat_rank of its most recent common ancestor (MRCA) with each
# extant row. One score per extant id is written to `mrca_variances`.
# NOTE(review): the helpers below come from hstrat's private `_auxiliary_lib`
# module, which may change between releases.
import hstrat
from hstrat._auxiliary_lib import (
    alifestd_mark_node_depth_asexual,
    alifestd_try_add_ancestor_id_col,
)

# Presumably adds the "ancestor_id" and "node_depth" columns read below -
# inferred from the helper names; confirm against hstrat.
post_phylos = [
    alifestd_mark_node_depth_asexual(alifestd_try_add_ancestor_id_col(phylo))
    for phylo in phylos
]
# Coerce "extant" to a strict boolean column: anything not equal to True
# becomes False.
for phylo in post_phylos:
    phylo["extant"] = phylo["extant"] == True

# Pull the needed columns out of each table as plain `xp` arrays.
decomposed_phylos = [
    {
        k: xp.array(phylo[k].to_numpy()) for k in [
            "extant",
            "dstream_data_id",
            "node_depth",
            "ancestor_id",
            "id",
            "hstrat_rank"
        ]
    } for phylo in post_phylos
]
# Cache per-phylogeny scalars: deepest node depth and number of rows.
for phylo in decomposed_phylos:
    phylo["max_depth"] = int(phylo["node_depth"].max())
    phylo["n"] = len(phylo["id"])

# Extant ids are taken from the first phylogeny and then looked up by
# dstream_data_id in every phylogeny.
extant_ids = decomposed_phylos[0]["dstream_data_id"][decomposed_phylos[0]["extant"]]

# One score per extant id; every slot is filled by the loop below.
mrca_variances = xp.empty(len(extant_ids), dtype=xp.float32)

# phylowise_mrcas = xp.empty((len(extant_ids), len(phylos), len(extant_ids)), dtype=xp.int16)

for j, extant_id in enumerate(tqdm.tqdm(extant_ids)):
    # result[i, k]: hstrat_rank of the node recorded for extant row k in phylogeny i.
    result = xp.full((len(decomposed_phylos), len(extant_ids)), -1, dtype=xp.int64)
    for i, phylo in enumerate(decomposed_phylos):
        # cur_positions[k] is the node that row k's pointer currently sits on;
        # every pointer starts on its own row's id.
        cur_positions = phylo["id"].copy()
        # First row whose dstream_data_id matches the target. It is used as an
        # array index below, which assumes id values equal row positions -
        # TODO confirm.
        real_id = cur_positions[phylo["dstream_data_id"] == extant_id][0]
        # temp[k]: largest node id at which row k's pointer has coincided with
        # the target's pointer so far (-1 = never).
        temp = xp.full(phylo["n"], -1, dtype=xp.int64)
        target_depth = phylo["node_depth"][real_id]
        # Sweep from the deepest level up to the root, lifting pointers one
        # level per pass.
        for depth in reversed(range(phylo["max_depth"] + 1)):
            if depth <= target_depth:
                # Record the target's current node for every row whose pointer
                # sits on that same node. Keeping the maximum id presumably
                # keeps the deepest shared ancestor, which only holds if
                # descendants have larger ids than their ancestors - verify.
                target_position = cur_positions[real_id]
                temp = xp.maximum(
                    temp,
                    xp.where(
                        cur_positions == target_position,
                        target_position,
                        -1,
                    ),
                    dtype=xp.int64,
                )
            # Move every pointer that sits on a node at this depth to that
            # node's ancestor.
            depth_mask = phylo["node_depth"][cur_positions] == depth
            cur_positions[depth_mask] = phylo["ancestor_id"][cur_positions[depth_mask]]
        # NOTE(review): a row still at -1 here would index the LAST node rather
        # than fail; presumably unreachable when the phylogeny has one root.
        result[i] = phylo["hstrat_rank"][temp[phylo["extant"]]]
        # phylowise_mrcas[j][i] = phylo["id"][temp[phylo["extant"]]]
    # Per extant row: variance across phylogenies divided by the squared mean,
    # then averaged over all extant rows.
    # NOTE(review): a column whose mean rank is 0 divides by zero here and
    # would turn the score into nan/inf.
    means = result.mean(axis=0)
    mrca_variances[j] = (xp.square(result - means).mean(axis=0) / xp.square(means)).mean()

In [46]:
import pickle

# Persist the per-tip scores and the raw phylogenies to the mounted Drive.
# `mrca_variances` is pickled as an `xp` array, so loading it needs the same
# array library available.
pickle_targets = (
    ("/content/drive/MyDrive/mrca_variance.pkl", mrca_variances),
    ("/content/drive/MyDrive/phylos.pkl", phylos),
)
for pickle_path, payload in pickle_targets:
    with open(pickle_path, 'wb') as f:
        pickle.dump(payload, f)