# Approximate Bottleneck Distance

In [8]:
import numpy as np
import scipy.linalg
import scipy.stats
import sklearn.metrics
from sklearn.cluster import KMeans
import numba
import matplotlib.pyplot as plt
import ot
import umap
import seaborn as sns

## Adding more example diagrams


In [9]:
from teaspoon.MakeData.PointCloud import testSetManifolds 
from teaspoon.TDA.Distance import dgmDist_Hera

We are going to generate 10 small examples from each of the 6 classes.

In [10]:
%%time
# Build the benchmark set: 10 point clouds of 300 points for each manifold class.
manifoldData = testSetManifolds(
    numDgms=10,
    numPts=300,
    permute=False,
    seed=0,
)

Generating torus clouds...
Generating annuli clouds...
Generating cube clouds...
Generating three cluster clouds...
Generating three clusters of three clusters clouds...
Generating sphere clouds...
Finished generating clouds and computing persistence.

CPU times: user 18.7 s, sys: 110 ms, total: 18.8 s
Wall time: 4.71 s


In [11]:
# Only the 1-dimensional persistence diagrams are used from here on.

# Birth-death coordinates: one array per diagram.
JustDgms_death = list(manifoldData['Dgm1'])

# Birth-lifetime coordinates: same births, second column replaced by death - birth.
JustDgms_lifetime = [
    np.vstack((dgm[:, 0], dgm[:, 1] - dgm[:, 0])).T
    for dgm in JustDgms_death
]

Here is the Wasserstein code. It takes a $p$ and a $q$: $p$ is the Wasserstein exponent, and $q$ is the exponent of the $L_q$ norm used as the ground distance between points. We also need the $q = \infty$ case, which is easily handled by switching to the Chebyshev metric.

In [12]:
def wasserstein_diagram_distance(pts0, pts1, y_axis='death', p=1, q=2):
    '''
    Compute the persistent p-Wasserstein distance between the diagrams pts0, pts1.

    Both diagrams are compared through a single optimal-transport problem in
    which each diagram gets one extra node standing for the diagonal, so every
    point can be sent either to a point of the other diagram or to the diagonal.

    pts0, pts1 : arrays with one row per diagram point; column 0 is read as the
        birth and column 1 as the death or lifetime, depending on y_axis.
        Diagrams with zero rows are accepted.
    y_axis = 'death' (default), or 'lifetime'
    p : exponent applied to every ground cost; the result is the 1/p power of
        the optimal total cost.
    q : exponent of the Minkowski norm used as ground distance between points;
        a non-finite q selects the Chebyshev (L-infinity) distance.

    Returns the distance as a float.
    Raises ValueError if y_axis is neither 'death' nor 'lifetime'.
    '''

    if y_axis not in ('death', 'lifetime'):
        raise ValueError('y_axis must be \'death\' or \'lifetime\'')

    pts0 = np.asarray(pts0)
    pts1 = np.asarray(pts1)
    n0 = pts0.shape[0]
    n1 = pts1.shape[0]

    def diagonal_cost(pts):
        # Ground cost of sending each point of `pts` to the diagonal.
        if pts.shape[0] == 0:
            return np.zeros(0)
        if y_axis == 'lifetime':
            return pts[:, 1]
        # L_q distance from (b, d) to ((b+d)/2, (b+d)/2) is (d-b) * 2**(1/q - 1).
        return (pts[:, 1] - pts[:, 0]) * (2 ** ((1.0 / q) - 1))

    extra_dist0 = diagonal_cost(pts0)
    extra_dist1 = diagonal_cost(pts1)

    if n0 == 0 or n1 == 0:
        # With an empty diagram the only feasible plan sends every point of the
        # other diagram to the diagonal, so the cost is known in closed form.
        # This also avoids pairwise_distances rejecting zero-row input and the
        # 0/0 normalisation of the marginals when both diagrams are empty.
        unmatched = np.concatenate([extra_dist0, extra_dist1])
        return np.power(np.sum(unmatched ** p), 1.0 / p)

    if np.isfinite(q):
        pairwise_dist = sklearn.metrics.pairwise_distances(pts0, pts1, metric="minkowski", p=q)
    else:
        pairwise_dist = sklearn.metrics.pairwise_distances(pts0, pts1, metric="chebyshev")

    # Augment the point-to-point costs with a last column (pts0 -> diagonal) and
    # a last row (pts1 -> diagonal); the diagonal-to-diagonal corner costs 0.
    all_pairs_ground_distance_a = np.hstack([pairwise_dist, extra_dist0[:, np.newaxis]])
    extra_row = np.zeros(all_pairs_ground_distance_a.shape[1])
    extra_row[:pairwise_dist.shape[1]] = extra_dist1
    all_pairs_ground_distance_a = np.vstack([all_pairs_ground_distance_a, extra_row])

    all_pairs_ground_distance_a = all_pairs_ground_distance_a**p

    # Each real point carries one unit of mass; the diagonal node of a diagram
    # carries as many units as the other diagram has points. Both marginals then
    # total n0 + n1 before being normalised to sum to 1.
    a = np.ones(n0+1)
    a[n0]=n1
    a = a/a.sum()
    b = np.ones(n1+1)
    b[n1]=n0
    b = b/b.sum()

    # Multiplying by (n0 + n1) undoes the normalisation of the marginals.
    return np.power((n0+n1)*ot.emd2(a, b, all_pairs_ground_distance_a),1.0/p)


### Modifications to do approximate bottleneck

Here we switch the metric to 'chebyshev', which is $L_\infty$. Also, the transport cost we want to return is the maximum cost of moving any single element under the transport plan, not the total cost. If the plan were always a matching, this would simply be the highest cost among the entries the plan uses. In theory, however, the returned solution could split mass, so we instead sum the total cost of moving each element and take the maximum of those sums.

Now, the optimal transport code is going to minimize total transport cost not the maximal transport cost, but in theory we can now take advantage of the limit and just raise all of the transport costs to the p-th power and find the optimal transport of that, which will basically be forced to minimize the maximal cost as a result.  Using this plan we compute the max cost under the original $L_\infty$ cost matrix without the p-th powers and take the max row sum / col sum of that. 

Lastly, because all of the diagonal points are grouped together into a single extra node (the last row/column), we only want the maximal cost of moving one of the real points in either diagram (these costs are equal when two real points move to each other). So we take care to drop the diagonal node's own row or column when summing to find the most costly move: the sum along the bottom row is the total cost of all the points in the second diagram that get moved to the diagonal, and similarly the last column for the first diagram.

In [19]:
def approx_bottleneck_diagram_distance(pts0, pts1, y_axis='death', p=1):
    '''
    Approximate the bottleneck distance between the diagrams pts0, pts1.

    An optimal-transport plan is computed for the Chebyshev (L-infinity) ground
    costs raised to the power p. The value returned is the largest per-point
    transport cost under that plan, measured with the un-powered costs.

    pts0, pts1 : arrays with one row per diagram point; column 0 is read as the
        birth and column 1 as the death or lifetime, depending on y_axis.
        Diagrams with zero rows are accepted.
    y_axis = 'death' (default), or 'lifetime'
    p : exponent applied to the ground costs before solving the transport problem.

    Returns the approximate bottleneck distance as a float.
    Raises ValueError if y_axis is neither 'death' nor 'lifetime'.
    '''

    if y_axis not in ('death', 'lifetime'):
        raise ValueError('y_axis must be \'death\' or \'lifetime\'')

    pts0 = np.asarray(pts0)
    pts1 = np.asarray(pts1)
    n0 = pts0.shape[0]
    n1 = pts1.shape[0]

    def diagonal_cost(pts):
        # Ground cost of sending each point of `pts` to the diagonal.
        if pts.shape[0] == 0:
            return np.zeros(0)
        if y_axis == 'lifetime':
            return pts[:, 1]
        # L-infinity distance from (b, d) to the diagonal is (d - b) / 2.
        return (pts[:, 1]-pts[:, 0])/2

    extra_dist0 = diagonal_cost(pts0)
    extra_dist1 = diagonal_cost(pts1)

    if n0 == 0 or n1 == 0:
        # With an empty diagram every point of the other one must go to the
        # diagonal, so the answer is the largest diagonal cost (0 if both are
        # empty). This avoids pairwise_distances and np.max failing on
        # zero-length input.
        unmatched = np.concatenate([extra_dist0, extra_dist1])
        return np.max(unmatched) if unmatched.size else 0.0

    pairwise_dist = sklearn.metrics.pairwise_distances(pts0, pts1, metric='chebyshev')

    # Augment the point-to-point costs with a last column (pts0 -> diagonal) and
    # a last row (pts1 -> diagonal); the diagonal-to-diagonal corner costs 0.
    all_pairs_ground_distance_a = np.hstack([pairwise_dist, extra_dist0[:, np.newaxis]])
    extra_row = np.zeros(all_pairs_ground_distance_a.shape[1])
    extra_row[:pairwise_dist.shape[1]] = extra_dist1
    all_pairs_ground_distance_a = np.vstack([all_pairs_ground_distance_a, extra_row])

    # The plan is optimised on the powered costs, but scored on the plain ones.
    all_pairs_ground_distance_ap = np.power(all_pairs_ground_distance_a,p)

    # Each real point carries one unit of mass; the diagonal node of a diagram
    # carries as many units as the other diagram has points.
    a = np.ones(n0+1)
    a[n0]=n1
    a = a/a.sum()
    b = np.ones(n1+1)
    b[n1]=n0
    b = b/b.sum()

    T=ot.emd(a, b, all_pairs_ground_distance_ap)

    # Row sums (diagonal row dropped) give the cost of moving each pts0 point;
    # column sums (diagonal column dropped) give the cost of moving each pts1
    # point. Multiplying by (n0 + n1) undoes the normalisation of the marginals.
    return (n0+n1)*np.max([np.max(np.sum(T[:-1,:]*all_pairs_ground_distance_a[:-1,:],axis=1)),
                            np.max(np.sum(T[:,:-1]*all_pairs_ground_distance_a[:,:-1],axis=0))])




If we expect no mass splitting this is Leland's solution which induces an actual matching from the plan since it rounds things to 1. 

In [43]:
def match_bottleneck_diagram_distance(pts0, pts1, y_axis='death', p=1):
    '''
    Approximate the bottleneck distance between the diagrams pts0, pts1 by
    reading a matching off an optimal-transport plan.

    The plan is computed for the Chebyshev (L-infinity) ground costs raised to
    the power p. The value returned is the largest un-powered ground cost among
    the plan entries that carry (numerically) exactly one unit of mass, so
    entries with split mass are ignored.

    pts0, pts1 : arrays with one row per diagram point; column 0 is read as the
        birth and column 1 as the death or lifetime, depending on y_axis.
        Diagrams with zero rows are accepted.
    y_axis = 'death' (default), or 'lifetime'
    p : exponent applied to the ground costs before solving the transport problem.

    Returns the approximate bottleneck distance as a float.
    Raises ValueError if y_axis is neither 'death' nor 'lifetime'.
    '''

    if y_axis not in ('death', 'lifetime'):
        raise ValueError('y_axis must be \'death\' or \'lifetime\'')

    pts0 = np.asarray(pts0)
    pts1 = np.asarray(pts1)
    n0 = pts0.shape[0]
    n1 = pts1.shape[0]

    def diagonal_cost(pts):
        # Ground cost of sending each point of `pts` to the diagonal.
        if pts.shape[0] == 0:
            return np.zeros(0)
        if y_axis == 'lifetime':
            return pts[:, 1]
        # L-infinity distance from (b, d) to the diagonal is (d - b) / 2.
        return (pts[:, 1] - pts[:, 0]) / 2

    extra_dist0 = diagonal_cost(pts0)
    extra_dist1 = diagonal_cost(pts1)

    if n0 == 0 or n1 == 0:
        # With an empty diagram every point of the other one must go to the
        # diagonal, so the answer is the largest diagonal cost (0 if both are
        # empty). This avoids pairwise_distances and np.max failing on
        # zero-length input.
        unmatched = np.concatenate([extra_dist0, extra_dist1])
        return np.max(unmatched) if unmatched.size else 0.0

    pairwise_dist = sklearn.metrics.pairwise_distances(pts0, pts1, metric='chebyshev')

    # Augment the point-to-point costs with a last column (pts0 -> diagonal) and
    # a last row (pts1 -> diagonal); the diagonal-to-diagonal corner costs 0.
    transport_cost = np.hstack([pairwise_dist, extra_dist0[:, np.newaxis]])
    extra_row = np.zeros(transport_cost.shape[1])
    extra_row[:pairwise_dist.shape[1]] = extra_dist1
    transport_cost = np.vstack([transport_cost, extra_row])

    # The plan is optimised on the powered costs, but scored on the plain ones.
    transport_cost_p = np.power(transport_cost, p)

    # Each real point carries one unit of mass; the diagonal node of a diagram
    # carries as many units as the other diagram has points.
    a = np.ones(n0+1)
    a[n0]=n1
    a = a/a.sum()
    b = np.ones(n1+1)
    b[n1]=n0
    b = b/b.sum()

    # We can just read off the max cost used in transport: rescaling by
    # (n0 + n1) puts the plan back in units of whole points, and the entries
    # equal to 1 are the moves of single points.
    transport_plan = (n0 + n1) * ot.emd(a, b, transport_cost_p)
    return np.max(transport_cost[np.isclose(transport_plan, 1.0)])


## Now let's see how this converges as we vary p


In [44]:
def approx_all_pairs_bottleneck_distance(diagrams, n=100, p=1):
    """Return the symmetric n x n matrix of approximate bottleneck distances.

    Only the first n entries of `diagrams` are used. Each unordered pair
    (diagonal included) is evaluated once with
    approx_bottleneck_diagram_distance in birth-death coordinates and
    exponent p, then mirrored across the diagonal.
    """
    dist_matrix = np.zeros((n, n))
    upper_triangle = ((row, col) for row in range(n) for col in range(row, n))
    for row, col in upper_triangle:
        value = approx_bottleneck_diagram_distance(
            diagrams[row], diagrams[col], y_axis='death', p=p
        )
        dist_matrix[row, col] = value
        dist_matrix[col, row] = value
    return dist_matrix

In [45]:
def match_all_pairs_bottleneck_distance(diagrams, n=100, p=1):
    """Return the symmetric n x n matrix of matching-based bottleneck distances.

    Only the first n entries of `diagrams` are used. Each unordered pair
    (diagonal included) is evaluated once with
    match_bottleneck_diagram_distance in birth-death coordinates and
    exponent p, then mirrored across the diagonal.
    """
    dist_matrix = np.zeros((n, n))
    upper_triangle = ((row, col) for row in range(n) for col in range(row, n))
    for row, col in upper_triangle:
        value = match_bottleneck_diagram_distance(
            diagrams[row], diagrams[col], y_axis='death', p=p
        )
        dist_matrix[row, col] = value
        dist_matrix[col, row] = value
    return dist_matrix

In [55]:
%%time
# Matching-based distances on the first 60 diagrams, one matrix per exponent p = 1..19.
d_match = [
    match_all_pairs_bottleneck_distance(JustDgms_death, n=60, p=exponent)
    for exponent in range(1, 20)
]

CPU times: user 34.8 s, sys: 150 ms, total: 34.9 s
Wall time: 35 s


In [56]:
%%time
# Row/column-sum based distances on the first 60 diagrams, one matrix per exponent p = 1..19.
d_approx = [
    approx_all_pairs_bottleneck_distance(JustDgms_death, n=60, p=exponent)
    for exponent in range(1, 20)
]

CPU times: user 34.3 s, sys: 177 ms, total: 34.4 s
Wall time: 34.5 s


In [57]:
# Entry-wise absolute gap between the two estimators; list index k holds exponent p = k + 1.
errors = [
    np.abs(d_match[idx] - d_approx[idx])
    for idx in range(19)
]

In [58]:
# For each exponent, report how often the two estimators agree and how far apart they get.
# NOTE: the value printed after "%" is a fraction in [0, 1]; it is not scaled by 100.
for p in range(1, 20):
    err = errors[p - 1]
    percent_correct = np.isclose(err, 0.0).sum() / 60.0**2  # 60 x 60 matrix entries
    print("p ",p, "%",percent_correct, "Error:  Max", err.max(),  "  Mean", err.mean())

p  1 % 1.0 Error:  Max 3.8413716652030416e-14   Mean 8.90920246527433e-16
p  2 % 1.0 Error:  Max 2.853273173286652e-14   Mean 7.929035514584213e-16
p  3 % 1.0 Error:  Max 3.8413716652030416e-14   Mean 8.372026066232823e-16
p  4 % 1.0 Error:  Max 3.8413716652030416e-14   Mean 8.054706593066087e-16
p  5 % 1.0 Error:  Max 3.8413716652030416e-14   Mean 8.068141062652264e-16
p  6 % 1.0 Error:  Max 3.8413716652030416e-14   Mean 8.0800335558153495e-16
p  7 % 1.0 Error:  Max 2.853273173286652e-14   Mean 7.701324146304324e-16
p  8 % 1.0 Error:  Max 2.853273173286652e-14   Mean 7.286436114966537e-16
p  9 % 1.0 Error:  Max 3.8413716652030416e-14   Mean 8.133867807686497e-16
p  10 % 1.0 Error:  Max 3.11972669919669e-14   Mean 7.337456259865544e-16
p  11 % 1.0 Error:  Max 3.552713678800501e-14   Mean 8.076255713578778e-16
p  12 % 1.0 Error:  Max 3.8413716652030416e-14   Mean 7.36243627791961e-16
p  13 % 1.0 Error:  Max 3.552713678800501e-14   Mean 7.745193375541249e-16
p  14 % 1.0 Error:  Max 3.841

Both methods give essentially the same results up to floating-point error (there was no mass splitting, which is not surprising).

Now to compare to the correct solution

In [59]:
import persim

In [60]:
def persim_all_pairs_bottleneck_distance(diagrams, n=100):
    """Return the symmetric n x n matrix of persim bottleneck distances.

    Only the first n entries of `diagrams` are used. Each unordered pair
    (diagonal included) is evaluated once with persim.bottleneck, then
    mirrored across the diagonal.
    """
    dist_matrix = np.zeros((n, n))
    upper_triangle = ((row, col) for row in range(n) for col in range(row, n))
    for row, col in upper_triangle:
        value = persim.bottleneck(diagrams[row], diagrams[col])
        dist_matrix[row, col] = value
        dist_matrix[col, row] = value
    return dist_matrix

In [61]:
%%time
# Reference bottleneck distances from persim on the same first 60 diagrams.
persim_distances = persim_all_pairs_bottleneck_distance(
    JustDgms_death,
    n=60,
)

CPU times: user 4min 57s, sys: 780 ms, total: 4min 58s
Wall time: 4min 59s


In [62]:
# Entry-wise absolute error of the matching-based estimate against persim;
# list index k holds exponent p = k + 1.
errors = [
    np.abs(persim_distances - d_match[idx])
    for idx in range(19)
]

In [69]:
# For each exponent, report the share of entries matching persim and the max / mean error.
for p in range(1, 20):
    err = errors[p - 1]
    percent_correct = np.isclose(err, 0.0).sum() / 60.0**2 * 100  # 60 x 60 matrix entries
    print("p ",p, "% exact", percent_correct, "Error:  Max", err.max(),  "  Mean", err.mean())

p  1 % exact 80.55555555555556 Error:  Max 0.15655291080474865   Mean 0.003993318486545525
p  2 % exact 93.83333333333333 Error:  Max 0.051955342292785645   Mean 0.0004415116210786263
p  3 % exact 96.11111111111111 Error:  Max 0.05195534229278562   Mean 0.00021736337699906707
p  4 % exact 97.33333333333334 Error:  Max 0.030719339847564475   Mean 0.00012309675829357973
p  5 % exact 98.05555555555556 Error:  Max 0.02148166298866261   Mean 7.206694533507553e-05
p  6 % exact 98.44444444444444 Error:  Max 0.02148166298866261   Mean 5.8784190980075584e-05
p  7 % exact 98.77777777777777 Error:  Max 0.0095406174659729   Mean 3.2610889111067215e-05
p  8 % exact 97.22222222222221 Error:  Max 0.0095406174659729   Mean 8.032467837207775e-05
p  9 % exact 91.55555555555556 Error:  Max 0.025418624281883254   Mean 0.0006747377839775393
p  10 % exact 84.0 Error:  Max 0.048307597637176514   Mean 0.0025294788823366846
p  11 % exact 78.66666666666666 Error:  Max 0.08409188687801361   Mean 0.00595635501564

In [70]:
# Entry-wise absolute error of the row/column-sum estimate against persim;
# list index k holds exponent p = k + 1.
errors = [
    np.abs(persim_distances - d_approx[idx])
    for idx in range(19)
]

In [71]:
# For each exponent, report the share of entries matching persim and the max / mean error.
for p in range(1, 20):
    err = errors[p - 1]
    percent_correct = np.isclose(err, 0.0).sum() / 60.0**2 * 100  # 60 x 60 matrix entries
    print("p ",p, "% exact",percent_correct, "Error:  Max", err.max(),  "  Mean", err.mean())

p  1 % exact 80.55555555555556 Error:  Max 0.15655291080474865   Mean 0.003993318486545525
p  2 % exact 93.83333333333333 Error:  Max 0.051955342292785645   Mean 0.0004415116210786263
p  3 % exact 96.11111111111111 Error:  Max 0.05195534229278562   Mean 0.00021736337699906707
p  4 % exact 97.33333333333334 Error:  Max 0.030719339847564475   Mean 0.00012309675829357973
p  5 % exact 98.05555555555556 Error:  Max 0.02148166298866261   Mean 7.206694533507553e-05
p  6 % exact 98.44444444444444 Error:  Max 0.02148166298866261   Mean 5.8784190980075584e-05
p  7 % exact 98.77777777777777 Error:  Max 0.0095406174659729   Mean 3.2610889111067215e-05
p  8 % exact 97.22222222222221 Error:  Max 0.0095406174659729   Mean 8.032467837207775e-05
p  9 % exact 91.55555555555556 Error:  Max 0.025418624281883254   Mean 0.0006747377839775393
p  10 % exact 84.0 Error:  Max 0.048307597637176514   Mean 0.0025294788823366846
p  11 % exact 78.66666666666666 Error:  Max 0.08409188687801361   Mean 0.00595635501564

Machine precision seems to be the issue beyond p=12 or so.  Let's look at the cost matrix 

In [72]:
# Look at the dynamic range of the Chebyshev cost matrix for one pair of
# diagrams once every entry is raised to the 12th power.
dmat = np.power(
    sklearn.metrics.pairwise_distances(
        JustDgms_death[0], JustDgms_death[1], metric="chebyshev"
    ),
    12.0,
)
(dmat.min(), dmat.max(), dmat.mean(), np.median(dmat))

(5.430764128529048e-32,
 62.996890640459235,
 0.20336644990508054,
 2.693788688751766e-07)

That looks like it is the issue. 