## Trees, Trees, more Trees, oh my!!

unrooted phylogenetic trees without pendant edges


### Imports

In [1]:
import numpy as np
import pytest

import geomstats.backend as gs
from geomstats.geometry.stratified.bhv_space import (
    Tree,
    TreeSpace,
    TreeTopology,
)
from geomstats.geometry.stratified.trees import Split
from geomstats.learning.frechet_mean import FrechetMean

gs.random.seed(666)

### TreeTopology tests

In [2]:
## Empty split not allowed
illegal_empty_topology = (Split({}, {0, 1, 2, 3, 4}), Split({0, 1}, {2, 3, 4}))
expected_error_message = "Empty splits like .* are not allowed."

with pytest.raises(ValueError, match=expected_error_message):
    TreeTopology(illegal_empty_topology)

## Singleton split not allowed
illegal_pendant_topology = (Split({3}, {0, 4, 1, 2}), Split({0, 1}, {2, 3, 4}))
expected_error_message = "Pendant edges / singleton splits like .* are not allowed."

with pytest.raises(ValueError, match=expected_error_message):
    TreeTopology(illegal_pendant_topology)

## This one is fine
legal_topology = (Split({3, 4}, {0, 1, 2}), Split({0, 1}, {2, 3, 4}))
legal_tt = TreeTopology(legal_topology)

assert legal_tt.splits == legal_topology
assert legal_tt.n_labels == 5

## Star tree should work
star_tree_topology = ()
star_tt = TreeTopology(star_tree_topology, n_labels=5)

assert star_tt.splits == star_tree_topology
assert star_tt.n_labels == 5

### TreeTopology tests

In [5]:
legal_lengths = [2, 4]
illegal_zero_length = [0, 3]
illegal_negative_length = [-2, 3]
illegal_number_of_lengths = [1, 2, 3]
star_tree_lengths = []

legal_topology = (Split({3, 4}, {0, 1, 2}), Split({0, 1}, {2, 3, 4}))
illegal_empty_topology = (Split({}, {0, 1, 2, 3, 4}), Split({0, 1}, {2, 3, 4}))
illegal_pendant_topology = (Split({3}, {0, 4, 1, 2}), Split({0, 1}, {2, 3, 4}))
star_tree_topology = ()

## Empty split not allowed
expected_error_message = "Empty splits like .* are not allowed."
with pytest.raises(ValueError, match=expected_error_message):
    Tree(illegal_empty_topology, legal_lengths)

## Singleton split not allowed
expected_error_message = "Pendant edges / singleton splits like .* are not allowed."
with pytest.raises(ValueError, match=expected_error_message):
    Tree(illegal_pendant_topology, legal_lengths)

## Zero length not allowed
expected_error_message = "Lengths must be positive. .* is not allowed."
with pytest.raises(ValueError, match=expected_error_message):
    Tree(legal_topology, illegal_zero_length)

## Negative length not allowed
expected_error_message = "Lengths must be positive. .* is not allowed."
with pytest.raises(ValueError, match=expected_error_message):
    Tree(legal_topology, illegal_negative_length)

## Lengths must match number of splits
expected_error_message = "Must have same number of edge lengths as edges. .*"
with pytest.raises(ValueError, match=expected_error_message):
    Tree(legal_topology, illegal_number_of_lengths)

## This one is fine
tree = Tree(legal_topology, legal_lengths)
assert len(tree.lengths) == len(legal_lengths)

## Star tree should work
star_tree = Tree(star_tree_topology, star_tree_lengths, n_labels=5)
assert len(star_tree.lengths) == 0

### TreeSpace tests

In [2]:
# BHV space only defined with N >= 4
expected_error_message = "BHV space only defined for N >= 4.*"
with pytest.raises(ValueError, match=expected_error_message):
    TreeSpace(n_labels=2)

# This is fine.
bhv5 = TreeSpace(n_labels=5)

# Make random point and check if belongs in space
random_tree = bhv5.random_point()
assert bhv5.belongs(random_tree)

# Check handling of non-tree
assert not bhv5.belongs("imposter")

# make random point in other space and make sure it doesn't belong
bhv7 = TreeSpace(n_labels=7)
assert not bhv7.belongs(random_tree)

### BHVMetric tests

https://plewis.github.io/applets/bhvspace/

^^ v helpful for visualisation and test creation

In [None]:
bhv5 = TreeSpace(n_labels=5)

top_quartile_topology = (Split({3, 4}, {0, 1, 2}), Split({0, 1}, {2, 3, 4}))
left_quartile_topology = (Split({0, 4}, {3, 1, 2}), Split({2, 3}, {0, 1, 4}))
right_quartile_topology = (Split({2, 3}, {0, 1, 4}), Split({2, 3}, {0, 1, 4}))
lengths = gs.array([3, 1])

top_tree = Tree(top_quartile_topology, lengths)
top_tree_shifted = Tree(top_quartile_topology, lengths + 1)
bottom_tree = Tree(right_quartile_topology, lengths + 1)
mirrored_left_tree = Tree(left_quartile_topology, lengths)
halfway_tree = Tree(top_quartile_topology, [0.5, 0.5])

## Distance same quadrant
dist = bhv5.metric.dist(top_tree, top_tree_shifted)
expected_dist = 2**0.5
assert gs.abs(dist - expected_dist) < gs.atol

## Distance different quadrant

## Distance different quadrant, geodesic through origin


## Geodesic same quadrant

## Geodesic different quadrant
geodesic = bhv5.metric.geodesic(top_tree, bottom_tree)
halfway = geodesic(0.5)
assert halfway.equal(halfway_tree)

## Geodesic different quadrant, geodesic through origin
geodesic = bhv5.metric.geodesic(top_tree, mirrored_left_tree)
geodesic(0.5)

{({0, 1}, {2, 3, 4}): np.int64(1), ({0, 1, 2}, {3, 4}): np.int64(3)}
{({0, 1}, {2, 3, 4}): np.int64(2), ({0, 1, 2}, {3, 4}): np.int64(4)}
1.4142135623730951
{({0, 1}, {2, 3, 4}): np.int64(1)}
{({0, 1}, {2, 3, 4}): 0}
first splits_t {({0, 1}, {2, 3, 4}): np.float64(0.5)}
supports {(2, 3, 4): (((({0, 1, 2}, {3, 4}),),), ((({0, 1, 4}, {2, 3}),),))}
index 0
splits_t_a {({0, 1, 2}, {3, 4}): np.float64(0.5)}
splits_t_b {}
splits_t {({0, 1}, {2, 3, 4}): np.float64(0.5), ({0, 1, 2}, {3, 4}): np.float64(0.5)}
{}
{}
first splits_t {}
supports {(0, 1, 2, 3, 4): (((({0, 1}, {2, 3, 4}), ({0, 1, 2}, {3, 4})),), ((({0, 1, 4}, {2, 3}), ({0, 4}, {1, 2, 3})),))}
index 0
splits_t_a {({0, 1}, {2, 3, 4}): np.float64(0.0), ({0, 1, 2}, {3, 4}): np.float64(0.0)}
splits_t_b {}
splits_t {({0, 1}, {2, 3, 4}): np.float64(0.0), ({0, 1, 2}, {3, 4}): np.float64(0.0)}


TypeError: 'NoneType' object cannot be interpreted as an integer

In [2]:
trees = np.array(
    [generate_random_tree(5, only_internal_edges=True, p_keep=1) for i in range(30)]
)

In [2]:
bhv_space = TreeSpace(n_labels=5)

### 3. Applications

* Frechet Mean via Sturm's Algorithm
* KNN

#### Frechet Mean via Sturm's Algorithm

In [11]:
fm = FrechetMean(
    bhv_space, sample_method="cyclic", max_iter=10000, epsilon=1e-4, window_length=10
)
fm.fit(trees, verbose=False)

fm.estimate_



[((({0, 1, 3}, {2, 4}), ({0, 1}, {2, 3, 4})), (np.float64(8.96874193767914e-05), np.float64(2.0898984294824082e-05)))]

In [10]:
fm = FrechetMean(
    bhv_space,
    sample_method="stochastic",
    max_iter=10000,
    epsilon=1e-4,
    window_length=10,
)
fm.fit(trees, verbose=False)

fm.estimate_

[((({0, 2, 4}, {1, 3}), ({0, 4}, {1, 2, 3})), (np.float64(8.014483411907475e-06), np.float64(6.467016616359967e-05)))]

In [14]:
### Silly test just changing edge weights
shorter_edges_tree = trees[0]
longer_edges_tree_lengths = gs.array([l + 1 for l in shorter_edges_tree.lengths])
longer_edges_tree = Tree(shorter_edges_tree.topology.splits, longer_edges_tree_lengths)

print(shorter_edges_tree)
print(longer_edges_tree)

fm.fit(np.array([shorter_edges_tree, longer_edges_tree]), verbose=False)
fm.estimate_

(({0, 1, 3}|{2, 4}, {0, 1}|{2, 3, 4});[1.85909642 1.20543094])
(({0, 1, 3}|{2, 4}, {0, 1}|{2, 3, 4});[2.85909642 2.20543094])


[((({0, 1, 3}, {2, 4}), ({0, 1}, {2, 3, 4})), (np.float64(2.3591670679589867), np.float64(1.7055015911542906)))]

In [13]:
### Silly test tree reflected across one axis
# look here https://plewis.github.io/applets/bhvspace/
lengths = gs.array([3, 1])
top_tree = Tree(gs.array([Split({3, 4}, {0, 1, 2}), Split({0, 1}, {2, 3, 4})]), lengths)
bottom_tree = Tree(
    gs.array([Split({2, 3}, {0, 1, 4}), Split({0, 1}, {2, 3, 4})]), lengths
)

print(top_tree)
print(bottom_tree)

fm.fit(np.array([top_tree, bottom_tree]), verbose=False)
# SHOULD BE {01|234} 1  and other edge collapses to 0
fm.estimate_



(({0, 1, 2}|{3, 4}, {0, 1}|{2, 3, 4});[3 1])
(({0, 1, 4}|{2, 3}, {0, 1}|{2, 3, 4});[3 1])


[((({0, 1, 4}, {2, 3}), ({0, 1}, {2, 3, 4})), (np.float64(0.00030003000300030005), np.float64(1.0)))]

#### KNN still illegal

In [12]:
# BAD

from geomstats.learning.knn import KNearestNeighborsClassifier
import numpy as np

print(np.array([i % 3 for i in range(10)]).reshape(-1, 1))
print(len(trees))

trees = np.array(trees).reshape(-1, 1)
knn = KNearestNeighborsClassifier(bhv_space, n_neighbors=3)
knn.fit(trees, np.array([i % 3 for i in range(10)]).reshape(-1, 1))

[[0]
 [1]
 [2]
 [0]
 [1]
 [2]
 [0]
 [1]
 [2]
 [0]]
30


TypeError: float() argument must be a string or a real number, not 'Tree'