In [1]:
%matplotlib widget
import matplotlib.pyplot as plt
import numpy as np
from numpy.random import default_rng

import ripser
import cechmate as cm
import persim
from persim import plot_diagrams

import networkx as nx
from networkx.algorithms import bipartite
from matplotlib.widgets import Slider, RadioButtons

In [2]:
# Quick look at the seeded generator. The diagram-generation cell re-seeds
# `rng` with 3 before drawing any diagrams, so this single draw does not affect them.
rng = default_rng(seed=0)
rng.random()

0.6369616873214543

In [3]:
def randomPD(card, generator=None):
    """Generate a random persistence diagram with `card` points.

    Births are drawn uniformly from [0, 1); each death is its birth plus an
    independent uniform [0, 1) lifetime, so no point lies below the diagonal.

    Parameters
    ----------
    card : int
        Number of points in the diagram.
    generator : numpy.random.Generator, optional
        Source of randomness. When omitted, the notebook-level ``rng`` is used,
        so existing calls keep consuming the shared generator in the same order.

    Returns
    -------
    numpy.ndarray of shape (card, 2)
        Column 0 holds births, column 1 holds deaths.
    """
    gen = rng if generator is None else generator
    # Draw all births first, then all lifetimes, to keep the original draw order.
    births = gen.random(card)
    deaths = births + gen.random(card)
    return np.column_stack((births, deaths))

In [4]:
# Two random diagrams of different cardinality; re-seeding first makes them reproducible.
rng = default_rng(3)
X0, Y0 = randomPD(4), randomPD(5)

f, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 8))
persim.plot_diagrams([X0, Y0], labels=['$X_0$', '$Y_0$'], ax=ax)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [5]:
def DiagProj(x):
    """Orthogonally project the point x = (birth, death) onto the diagonal.

    Returns a two-element list ``[m, m]`` with ``m = (x[0] + x[1]) / 2``.
    """
    midpoint = (x[0] + x[1]) / 2
    return [midpoint, midpoint]

In [6]:
# Diagonal padding: X0prime = projections of Y0 (added to the X0 side) and
# Y0prime = projections of X0 (added to the Y0 side). They are built with DiagProj
# so their coordinates match the projection nodes created later exactly.
# reshape keeps an (n, 2) shape even when a diagram is empty.
X0prime = np.array([DiagProj(y) for y in Y0]).reshape(-1, 2)
Y0prime = np.array([DiagProj(x) for x in X0]).reshape(-1, 2)

f, ax = plt.subplots(1, 1, figsize=(8, 8))
persim.plot_diagrams([X0, Y0, X0prime, Y0prime], ax=ax,
                     labels=['$X_0$', '$Y_0$', "$X_0'$", "$Y_0'$"])
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [40]:
# Bipartite graph between the padded diagrams:
#   left side  U = X0 plus X0prime (diagonal projections of Y0), tagged bipartite=0
#   right side V = Y0 plus Y0prime (diagonal projections of X0), tagged bipartite=1
# Edge weights are:
#   - the L-infinity distance between two off-diagonal points;
#   - 0 between two diagonal projections;
#   - (death - birth) / 2 from a point to its own projection.
# Presumably thresholded with getGeps to test bottleneck distance <= eps via perfect matchings.
U = [tuple(u) for u in np.append(X0, X0prime, axis=0)]
V = [tuple(v) for v in np.append(Y0, Y0prime, axis=0)]

# Exact row membership. `tuple in ndarray` is NOT row membership: it evaluates
# (arr == tuple).any(), which is true if any single coordinate matches.
X0_pts = {tuple(p) for p in X0}
Y0_pts = {tuple(p) for p in Y0}
X0prime_pts = {tuple(p) for p in X0prime}
Y0prime_pts = {tuple(p) for p in Y0prime}

B = nx.Graph()
B.add_nodes_from(U, bipartite=0)
B.add_nodes_from(V, bipartite=1)

for u in U:
    for v in V:
        if u in X0_pts and v in Y0_pts:
            B.add_edge(u, v, weight=max(np.abs(u[0] - v[0]), np.abs(u[1] - v[1])))
        elif u in X0prime_pts and v in Y0prime_pts:
            B.add_edge(u, v, weight=0)

# Each off-diagonal point may instead be matched to its own diagonal projection.
for u in X0:
    B.add_edge(tuple(u), tuple(DiagProj(u)), weight=(u[1] - u[0]) / 2)
for v in Y0:
    B.add_edge(tuple(v), tuple(DiagProj(v)), weight=(v[1] - v[0]) / 2)

In [43]:
# The two sides of B. Passing top_nodes=U fixes which set is returned first and
# avoids networkx's AmbiguousSolution error should the graph be disconnected.
nx.bipartite.sets(B, top_nodes=U)


({(0.08564916714362436, 0.17977780938402355),
  (0.2368105065960997, 0.6699374468325735),
  (0.4825909135674835, 0.4825909135674835),
  (0.5821620360643678, 0.7419009507014463),
  (0.6588407644957595, 0.6588407644957595),
  (0.7549016239540904, 0.7549016239540904),
  (0.8012744652063969, 1.280325763347231),
  (0.8693618179137114, 0.8693618179137114),
  (1.027976437128285, 1.027976437128285)},
 {(0.11367201992140341, 0.8515098072135636),
  (0.13271348826382395, 0.13271348826382395),
  (0.39122819049566204, 1.3474954453317607),
  (0.4306280204141778, 1.079175227494003),
  (0.4533739767143366, 0.4533739767143366),
  (0.5167401826213637, 0.8009413463701551),
  (0.662031493382907, 0.662031493382907),
  (0.7345771514092145, 1.3213757228473553),
  (1.040800114276814, 1.040800114276814)})

In [44]:
def plotBipartite(G, ax):
    """Draw a bipartite graph as two columns on `ax`.

    Layout and colors come from each node's 'bipartite' attribute:
    - a node with side s is placed at x = s + 1, so side 0 is at x=1 and side 1 at x=2;
    - nodes are stacked upward in G's node order within each side;
    - nodes are colored by side.

    This uses only G itself (not the notebook globals), so it works for B and for
    any thresholded copy or other graph tagged the same way.

    Parameters
    ----------
    G : networkx.Graph
        Every node must carry a 'bipartite' attribute (0 or 1).
    ax : matplotlib.axes.Axes
        Axes to draw on.
    """
    color_map = [G.nodes[node]['bipartite'] for node in G.nodes]
    pos = {}
    next_row = {}  # side -> next free vertical slot in that column
    for node, side in zip(G.nodes, color_map):
        row = next_row.get(side, 0)
        pos[node] = (side + 1, row)
        next_row[side] = row + 1
    nx.draw(G, pos=pos, node_color=color_map, with_labels=False, ax=ax)
    

In [45]:
# Draw the full (unthresholded) bipartite graph B.
f, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 8))
plotBipartite(B, ax)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [46]:
def getGeps(G, eps):
    """Return a copy of G without the edges whose 'weight' exceeds eps.

    Nodes are all kept (some may become isolated); G itself is not modified.
    """
    too_heavy = [edge for edge in G.edges if G.edges[edge]['weight'] > eps]
    thresholded = nx.Graph(G)
    thresholded.remove_edges_from(too_heavy)
    return thresholded

In [47]:

# Interactive view of the eps-threshold graph getGeps(B, eps).
fig, ax = plt.subplots(1, 1, figsize=(8, 8))
fig.subplots_adjust(left=0.25, bottom=0.25)
axeps_graph = fig.add_axes([0.25, 0.1, 0.65, 0.03])
# Use a name no other cell reuses: matplotlib widgets only stay responsive while
# a reference to them is held, so rebinding a shared name like `seps` later
# would silently disable this slider.
seps_graph = Slider(axeps_graph, 'eps', 0.0, 1.5, valinit=0, valstep=0.05)

def updateGraph(val, ax=ax, fig=fig):
    """Slider callback: redraw getGeps(B, val) on this cell's axes.

    `ax` and `fig` are bound as defaults when the function is defined. Later
    cells rebind the global `ax`, and the callback must keep drawing on the
    figure it belongs to.
    """
    ax.clear()
    plotBipartite(getGeps(B, val), ax)
    fig.canvas.draw_idle()

seps_graph.on_changed(updateGraph)
updateGraph(seps_graph.val)  # draw the initial state before the slider is moved
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [48]:
# Compare against persim's own bottleneck computation and matching plot.
# NOTE(review): this rebinds the global `ax`; any earlier callback that reads the
# global `ax` at call time will start drawing on this figure instead.
f, ax =plt.subplots()
# With matching=True the result unpacks as (value, (matchidx, D)). The first value
# is presumably the bottleneck distance, despite the name `bn_matching` -- TODO
# confirm against the installed persim version.
bn_matching, (matchidx, D) = persim.bottleneck(X0, Y0, matching=True) 
# NOTE(review): the matching is drawn with the Wasserstein-matching helper; confirm
# that this is intended (persim may provide a bottleneck-specific plot) and that
# its positional parameters really are (matchidx, D).
persim.wasserstein_matching(X0, Y0, matchidx, D, ax=ax)
plt.show()


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [55]:
# Interactive view of a maximum matching in the eps-threshold graph getGeps(B, eps).
# Matched pairs are drawn as green segments over the two diagrams. The title
# reports whether the matching is perfect (every node of the graph matched).
fig_match, ax_match = plt.subplots()
fig_match.subplots_adjust(left=0.25, bottom=0.25)
axeps_match = fig_match.add_axes([0.25, 0.1, 0.65, 0.03])
# Use a name no other cell reuses: matplotlib widgets only stay responsive while
# a reference is held, so rebinding another cell's `seps` would disable its slider.
seps_match = Slider(axeps_match, 'eps', 0.0, 1.5, valinit=0, valstep=0.05)

def update(val):
    """Slider callback: redraw X0, Y0 and a maximum matching at threshold val."""
    ax_match.clear()
    G_eps = getGeps(B, val)
    # The dict holds both directions: matching[a] == b and matching[b] == a.
    matching = nx.bipartite.hopcroft_karp_matching(G_eps, top_nodes=U)
    persim.plot_diagrams([X0, Y0], ax=ax_match, labels=['$X_0$', '$Y_0$'])
    diag_U = {tuple(p) for p in X0prime}
    diag_V = {tuple(p) for p in Y0prime}
    for u in U:  # one side only, so each matched pair is drawn once
        v = matching.get(u)
        if v is None:
            continue  # u is unmatched at this threshold
        if u in diag_U and v in diag_V:
            continue  # diagonal-to-diagonal padding pairs carry no information
        ax_match.plot((u[0], v[0]), (u[1], v[1]), color='green')
    perfect = len(matching) == G_eps.number_of_nodes()
    ax_match.set_title(f"eps = {val:.2f} ({'perfect' if perfect else 'partial'} matching)")
    fig_match.canvas.draw_idle()

seps_match.on_changed(update)
update(seps_match.val)  # draw the initial state before the slider is moved
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [77]:
# Maximum matching of the eps = 0.5 threshold graph, and the partner of U[0].
# The matching dict holds both directions; indexing raises KeyError if U[0]
# is unmatched at this threshold.
G_eps = getGeps(B, 0.5)
matching = nx.bipartite.hopcroft_karp_matching(G_eps, top_nodes=U)
matching[U[0]]

(0.13271348826382395, 0.13271348826382395)