## Sturm's Algorithm for the Frechet Mean in BHV Space

In this notebook we walk through Frechet mean calculations in BHV tree space. BHV space is a non-positively curved (CAT(0)) geodesic metric space without an exponential map, so gradient-descent-based methods cannot be used. Instead, we estimate the mean with Sturm's algorithm, which only requires geodesics.
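For reference, the quantity being estimated is the Frechet mean: the tree minimising the sum of squared BHV distances $d$ to the sample trees $T_1, \dots, T_N$. Because BHV space is CAT(0), this minimiser exists and is unique.

$$
\bar{T} \;=\; \operatorname*{arg\,min}_{T} \; \sum_{i=1}^{N} d(T, T_i)^2
$$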

This tutorial will be structured as follows:
1. Imports
2. Intro to BHV Space
3. Applications
    * Mean
    * KNN


### 1. Imports

In [19]:
import numpy as np

import geomstats.backend as gs
from geomstats.geometry.stratified.bhv_space import (
    Tree,
    TreeSpace,
    generate_random_tree,
)
from geomstats.geometry.stratified.trees import Split
from geomstats.learning.frechet_mean import FrechetMean

gs.random.seed(666)

### 2. BHV Space

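BHV (Billera-Holmes-Vogtmann) space is the space of metric trees with labelled leaves. It was designed with phylogenetic trees in mind. BHV space is a stratified space: each tree topology contributes a Euclidean orthant whose coordinates are the edge lengths, and these orthants are glued together along lower-dimensional faces, where one or more edges shrink to length zero. As such, its geodesics are defined by XXXX.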
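Which orthant a tree lives in is determined by its set of splits, the bipartitions of the labels obtained by removing each internal edge. Two splits A|A' and B|B' can appear in the same tree only if they are compatible, i.e. at least one of the four intersections of their sides is empty, and a set of pairwise compatible splits is exactly a tree topology. The minimal sketch below checks this for the splits used in the reflected-trees test further down. It encodes a split by one of its sides as a plain Python set; the helper names `compatible` and `is_valid_topology` are hypothetical and are not part of geomstats' `Split` API.

```python
from itertools import combinations

LABELS = frozenset(range(5))  # the 5 leaf labels used in this notebook


def compatible(side_a, side_b, labels=LABELS):
    """Return True if splits side_a|rest and side_b|rest can coexist in one tree.

    Two splits A|A' and B|B' are compatible iff at least one of the four
    intersections A&B, A&B', A'&B, A'&B' is empty.
    """
    a, a_rest = frozenset(side_a), labels - frozenset(side_a)
    b, b_rest = frozenset(side_b), labels - frozenset(side_b)
    return any(not (x & y) for x in (a, a_rest) for y in (b, b_rest))


def is_valid_topology(sides, labels=LABELS):
    """A set of splits forms a tree topology (one orthant) iff pairwise compatible."""
    return all(compatible(s, t, labels) for s, t in combinations(sides, 2))


# Splits from the reflected-trees test, each given by one of its sides.
shared = {0, 1}        # {0, 1}|{2, 3, 4}
top_only = {3, 4}      # {3, 4}|{0, 1, 2}
bottom_only = {2, 3}   # {2, 3}|{0, 1, 4}

print(is_valid_topology([shared, top_only]))     # True: one orthant
print(is_valid_topology([shared, bottom_only]))  # True: another orthant
print(compatible(top_only, bottom_only))         # False: no tree has both
```

The last line is why `top_tree` and `bottom_tree` sit in different orthants that meet only along the face where their length-3 edge shrinks to zero.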


For more details on this space, please see the original paper -- here --
For a simple visualisation to gain intuition about the non-Euclidean nature of this space, see this link -- here -- 

In [9]:
# 30 random trees on 5 labels, keeping only internal edges.
trees = np.array(
    [generate_random_tree(5, only_internal_edges=True, p_keep=1) for _ in range(30)]
)

In [10]:
bhv_space = TreeSpace(n_labels=5)

### 3. Applications

* Frechet Mean via Sturm's Algorithm
* KNN

#### Frechet Mean via Sturm's Algorithm
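Sturm's algorithm (the "inductive mean") needs only geodesics: start at a randomly drawn data point, then repeatedly draw a data point uniformly at random and move a fraction 1/(k+1) of the way towards it along the geodesic. In a CAT(0) space such as BHV space these iterates converge to the Frechet mean. The sketch below is a minimal, generic version assuming a user-supplied `geodesic_point(x, y, t)` function (a hypothetical name); it is not geomstats' implementation and uses a fixed iteration count instead of the `epsilon`/`window_length` stopping rule of `FrechetMean`. It is run in Euclidean space, where the geodesic is linear interpolation, so the result should land near the arithmetic mean.

```python
import random


def sturm_mean(points, geodesic_point, n_iter=10_000, seed=0):
    """Sturm's inductive mean estimator.

    points: list of data points.
    geodesic_point(x, y, t): point at fraction t along the geodesic from x to y.
    """
    rng = random.Random(seed)
    estimate = rng.choice(points)  # S_1: a randomly drawn data point
    for k in range(1, n_iter):
        target = rng.choice(points)  # draw a data point uniformly at random
        # S_{k+1}: move a fraction 1/(k+1) of the way towards the draw
        estimate = geodesic_point(estimate, target, 1.0 / (k + 1))
    return estimate


def euclidean_geodesic(x, y, t):
    """In Euclidean space the geodesic is linear interpolation."""
    return (1.0 - t) * x + t * y


data = [0.0, 1.0, 5.0]
print(sturm_mean(data, euclidean_geodesic))  # close to the arithmetic mean 2.0
```

In Euclidean space the iterate is exactly the running average of the draws, which makes the convergence easy to see; in BHV space the same loop applies, with `geodesic_point` replaced by the BHV geodesic.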

In [32]:
# Estimate the Frechet mean of the 30 random trees with Sturm's algorithm.
fm = FrechetMean(bhv_space, max_iter=10000, epsilon=1e-3, window_length=10)
fm.fit(trees, verbose=False)

fm.estimate_

[((({0, 2, 4}, {1, 3}), ({0, 2}, {1, 3, 4})), (np.float64(0.0002392963079847288), np.float64(0.0003862849201304039)))]

In [None]:
# Sanity check: two trees with the same topology whose edge lengths differ
# by 1. Both lie in the same orthant, so their Frechet mean should be close
# to the midpoint of the two length vectors.
shorter_edges_tree = trees[0]
longer_edges_tree_lengths = gs.array(shorter_edges_tree.lengths) + 1
longer_edges_tree = Tree(shorter_edges_tree.topology.splits, longer_edges_tree_lengths)

print(shorter_edges_tree)
print(longer_edges_tree)

fm.fit(np.array([shorter_edges_tree, longer_edges_tree]), verbose=False)
fm.estimate_

(({0, 2, 4}|{1, 3}, {0, 2}|{1, 3, 4});[5.12387921 6.0788277 ])
(({0, 2, 4}|{1, 3}, {0, 2}|{1, 3, 4});[6.12387921 7.0788277 ])


[((({0, 2, 4}, {1, 3}), ({0, 2}, {1, 3, 4})), (np.float64(5.637884813998665), np.float64(6.592833301645971)))]
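The exact answer for this check can be computed directly. Within one orthant the geodesic is a straight line, so the Frechet mean of two trees with the same topology is the midpoint of their edge-length vectors. The sketch below uses the printed (rounded) lengths and the printed Sturm estimate, and measures where the estimate sits on the segment between the two trees.

```python
import numpy as np

# Edge lengths printed above for the two trees (same topology).
shorter = np.array([5.12387921, 6.0788277])
longer = shorter + 1.0
# Sturm estimate printed above.
estimate = np.array([5.637884813998665, 6.592833301645971])

# Straight-line geodesic within one orthant: the exact mean is the midpoint.
exact_mean = (shorter + longer) / 2

# Position of the estimate along the segment from shorter to longer
# (0 = shorter tree, 1 = longer tree, 0.5 = exact mean).
fraction = (estimate - shorter) / (longer - shorter)

print(exact_mean)
print(fraction)  # about 0.514 in both coordinates
```

Both coordinates give the same fraction, so the estimate lies on the segment between the two trees, about 0.014 of its length past the exact midpoint. This is most likely the stochastic error of Sturm's algorithm stopped at a finite tolerance.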

In [33]:
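# Sanity check: two trees that are mirror images across a shared axis. They
# share the split {0,1}|{2,3,4} (length 1) and differ in one incompatible
# split of length 3: {3,4}|{0,1,2} versus {2,3}|{0,1,4}.
# Interactive BHV space visualisation: https://plewis.github.io/applets/bhvspace/
lengths = gs.array([3, 1])
top_tree = Tree(gs.array([Split({3, 4}, {0, 1, 2}), Split({0, 1}, {2, 3, 4})]), lengths)
bottom_tree = Tree(
    gs.array([Split({2, 3}, {0, 1, 4}), Split({0, 1}, {2, 3, 4})]), lengths
)

print(top_tree)
print(bottom_tree)

fm.fit(np.array([top_tree, bottom_tree]), verbose=False)
# Expected mean: only {0,1}|{2,3,4} with length 1. The length-3 edge collapses
# to 0, because the midpoint of the geodesic lies on the face where the two
# orthants meet.
fm.estimate_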

(({0, 1, 2}|{3, 4}, {0, 1}|{2, 3, 4});[3 1])
(({0, 1, 4}|{2, 3}, {0, 1}|{2, 3, 4});[3 1])


[((({0, 1, 2}, {3, 4}), ({0, 1}, {2, 3, 4})), (np.float64(0.031999999999999945), np.float64(1.0)))]
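The expected mean here can be checked by hand, since the Frechet mean of two points is the midpoint of the geodesic joining them. In the setting of this test (the trees share every split except one on each side, and those two splits are incompatible), the geodesic in the Owen-Provan description has a single support pair: the first tree's own edge shrinks to 0 while the second tree's own edge grows from 0, and shared edges are interpolated linearly. The sketch below implements only that special case; the dict encoding (a split identified by its side containing label 0) and the helper names are hypothetical, and general pairs of trees need the full Owen-Provan algorithm.

```python
import math

# Trees as {split: edge length}; a split is keyed by its side containing label 0.
top_tree = {frozenset({0, 1, 2}): 3.0, frozenset({0, 1}): 1.0}     # {3,4}|{0,1,2}
bottom_tree = {frozenset({0, 1, 4}): 3.0, frozenset({0, 1}): 1.0}  # {2,3}|{0,1,4}


def _parts(tree_a, tree_b):
    """Shared splits, each tree's own splits, and the norms of the own edges."""
    common = tree_a.keys() & tree_b.keys()
    only_a = tree_a.keys() - common
    only_b = tree_b.keys() - common
    norm_a = math.sqrt(sum(tree_a[s] ** 2 for s in only_a))
    norm_b = math.sqrt(sum(tree_b[s] ** 2 for s in only_b))
    return common, only_a, only_b, norm_a, norm_b


def single_pair_distance(tree_a, tree_b):
    """Geodesic length, assuming a single support pair."""
    common, _, _, norm_a, norm_b = _parts(tree_a, tree_b)
    common_sq = sum((tree_a[s] - tree_b[s]) ** 2 for s in common)
    return math.sqrt((norm_a + norm_b) ** 2 + common_sq)


def single_pair_point(tree_a, tree_b, t):
    """Point at fraction t along the geodesic, assuming a single support pair."""
    common, only_a, only_b, norm_a, norm_b = _parts(tree_a, tree_b)
    # Shared splits: linear interpolation of their lengths.
    point = {s: (1 - t) * tree_a[s] + t * tree_b[s] for s in common}
    if norm_a + norm_b == 0:  # identical topologies: straight line only
        return point
    # tree_a's own edges shrink to 0 at t_star, then tree_b's own edges grow.
    t_star = norm_a / (norm_a + norm_b)
    if t < t_star:
        scale = ((1 - t) * norm_a - t * norm_b) / norm_a
        point.update({s: scale * tree_a[s] for s in only_a})
    elif t > t_star:
        scale = (t * norm_b - (1 - t) * norm_a) / norm_b
        point.update({s: scale * tree_b[s] for s in only_b})
    return point


print(single_pair_distance(top_tree, bottom_tree))    # 6.0
print(single_pair_point(top_tree, bottom_tree, 0.5))  # {frozenset({0, 1}): 1.0}
```

The midpoint keeps only {0,1}|{2,3,4} with length 1, matching the expected mean; the Sturm estimate above (length about 0.032 for the collapsing edge and 1.0 for the shared one) is close to it.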

#### KNN (not yet supported for trees)

In [12]:
# Expected to fail: the KNN classifier validates its input as a numeric
# array, and Tree objects cannot be converted to float.
from geomstats.learning.knn import KNearestNeighborsClassifier

# One class label per tree (3 classes, cycled).
labels = np.array([i % 3 for i in range(len(trees))])
print(len(trees))

knn = KNearestNeighborsClassifier(bhv_space, n_neighbors=3)
knn.fit(trees.reshape(-1, 1), labels)

30


TypeError: float() argument must be a string or a real number, not 'Tree'
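One possible workaround, which is not part of the notebook or of geomstats' KNN class, is to never hand `Tree` objects to the classifier and to work from a precomputed distance matrix instead. The sketch below implements plain k-nearest-neighbour voting on such a matrix; the helper names are hypothetical, and it is checked on points on the real line. For trees, `dist` would be the BHV distance; we assume it is available as `bhv_space.metric.dist`, which has not been verified here. scikit-learn's `KNeighborsClassifier` also accepts `metric="precomputed"` for matrices like this.

```python
from collections import Counter

import numpy as np


def pairwise_distances(queries, train, dist):
    """Matrix whose (i, j) entry is dist(queries[i], train[j])."""
    return np.array([[dist(q, x) for x in train] for q in queries])


def knn_predict(dist_matrix, train_labels, k=3):
    """Majority vote among the k nearest training points of each query."""
    predictions = []
    for row in dist_matrix:
        nearest = np.argsort(row)[:k]  # indices of the k smallest distances
        votes = Counter(train_labels[j] for j in nearest)
        predictions.append(votes.most_common(1)[0][0])
    return np.array(predictions)


# Toy check on the real line, where dist is the absolute difference.
train = [0.0, 0.1, 0.2, 5.0, 5.1, 5.2]
labels = np.array([0, 0, 0, 1, 1, 1])
queries = [0.05, 5.05]
d = pairwise_distances(queries, train, lambda a, b: abs(a - b))
print(knn_predict(d, labels, k=3))  # [0 1]
```

With trees, the same two calls would take `trees` as both `queries` and `train` (or a held-out subset as `queries`), so only distances, never `Tree` objects, reach the voting step.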