In [4]:
import numpy as np
import gudhi as gd
import matplotlib.pyplot as plt

In [5]:
def create_persistence_complex(dataset, complex_type='alpha', max_dim=3):
    '''
    Given a dataset, constructs a persistence complex out of that dataset

    Parameters:
        dataset: point cloud handed to the GUDHI complex constructor as `points`
        complex_type: 'alpha' (default) or 'rips'
        max_dim: highest simplex dimension kept in the complex

    Returns:
        pc: list of Simplex objects in filtration order
        sd: dict mapping each simplex's vertex tuple to [False, i], where i is
            the position of that simplex in pc

    Raises:
        TypeError: if complex_type is neither 'alpha' nor 'rips'
    '''
    if complex_type == 'alpha':
        # Alpha filtration values are squared radii, so take the square root
        # to get the degree. The alpha complex is built in full and then cut
        # down to max_dim.
        stree = gd.AlphaComplex(points=dataset).create_simplex_tree()
        stree.prune_above_dimension(max_dim)
        to_degree = np.sqrt
    elif complex_type == 'rips':
        # create_simplex_tree() defaults to max_dimension=1 for Rips, which
        # would silently ignore max_dim and only produce vertices and edges.
        rips = gd.RipsComplex(points=dataset)
        stree = rips.create_simplex_tree(max_dimension=max_dim)
        # Rips filtration values are already distances (not squared).
        to_degree = float
    else:
        raise TypeError("Complex type must be 'rips' or 'alpha'")

    pc = []
    sd = {}

    for i, (vertices, filtration_value) in enumerate(stree.get_filtration()):
        pc.append(Simplex(vertices, to_degree(filtration_value)))
        # [flag, filtration index]; the flag starts out False for every simplex
        sd[tuple(vertices)] = [False, i]

    return pc, sd

In [6]:
class Simplex:
    '''
    A simplex of a filtered complex.

    Attributes:
        vertices: the vertex labels of the simplex
        dim: dimension of the simplex (number of vertices minus one)
        deg: degree (filtration value) at which the simplex appears
        marked: flag, initially False
        idx: optional position of the simplex in the filtration, or None
    '''

    def __init__(self, vertices, degree, idx=None):
        self.vertices = vertices
        self.dim = len(vertices) - 1
        self.deg = degree
        self.marked = False
        # Previously never assigned although __str__ read it, which made
        # str(simplex) raise AttributeError.
        self.idx = idx

    def __str__(self):
        text = f"Vertices: {self.vertices}, degree: {self.deg}"
        if self.idx is not None:
            text += f" at index {self.idx}"
        return text

    def compute_reduced_boundary_chain(self, simplex_dictionary):
        '''
        Returns the set of codimension-1 faces of this simplex (as vertex
        tuples) whose flag simplex_dictionary[face][0] is falsy.

        A 0-dimensional simplex has an empty boundary.
        Raises KeyError if a face is missing from simplex_dictionary.
        '''
        boundary_chain = set()

        if self.dim == 0:
            return boundary_chain

        verts = tuple(self.vertices)
        for i in range(len(verts)):
            # drop the i-th vertex to obtain the i-th face
            face = verts[:i] + verts[i + 1:]
            # NOTE(review): this keeps faces whose flag is False. The
            # Zomorodian-Carlsson algorithm keeps the *marked* faces instead;
            # confirm which is intended before anything starts setting the
            # dictionary flag to True.
            if not simplex_dictionary[face][0]:
                boundary_chain.add(face)

        return boundary_chain
        

In [7]:
def compute_intervals(persistence_complex, simplex_dictionary, max_dim=3):
    '''
    Computes persistence homology intervals given a persistence complex

    Parameters:
        persistence_complex: list of simplices in filtration order; each needs
            the attributes dim, deg and marked
        simplex_dictionary: dict mapping a vertex tuple to [flag, index], where
            index is the simplex's position in persistence_complex
        max_dim: number of homology dimensions to report (0 .. max_dim - 1)

    Returns:
        A list of max_dim lists; entry k holds the (birth, death) degree pairs
        of dimension k. Classes that never die have death float('inf').

    Side effects:
        Sets simplex.marked = True on every simplex whose reduced boundary is
        empty.
    '''

    #initialize data structures
    # One independent list per dimension: [[]] * max_dim would repeat the
    # *same* list object, so every dimension would share its intervals.
    interval_sets = [[] for _ in range(max_dim)]
    # T[i] holds the reduced chain whose highest-index face is simplex i
    T = [None] * len(persistence_complex)

    for simplex_j in persistence_complex:
        d = remove_pivot_rows(simplex_j, simplex_dictionary, T)
        if not d:
            # empty reduced boundary: simplex_j creates a class
            simplex_j.marked = True
            continue

        # simplex_j destroys the class created by its highest-index face
        i = max_index(d, simplex_dictionary)
        simplex_i = persistence_complex[i]
        T[i] = d
        k = simplex_i.dim
        # dimensions >= max_dim have no slot in the result
        if k < max_dim:
            interval_sets[k].append((simplex_i.deg, simplex_j.deg))

    # marked simplices that were never paired give unbounded intervals
    for j, simplex in enumerate(persistence_complex):
        if simplex.marked and T[j] is None:
            k = simplex.dim
            if k < max_dim:
                # list.append returns None, so its result must not be
                # assigned back into interval_sets[k]
                interval_sets[k].append((simplex.deg, float('inf')))

    return interval_sets

In [8]:
def remove_pivot_rows(simplex, simplex_dictionary, T):
    '''
    Reduces the boundary chain of a simplex against the chains stored in T

    Starting from simplex.compute_reduced_boundary_chain(simplex_dictionary),
    repeatedly looks up the chain's highest index and, while T holds a chain
    at that index, replaces the chain by its symmetric difference with the
    stored one. Returns the resulting (possibly empty) set.
    '''
    chain = simplex.compute_reduced_boundary_chain(simplex_dictionary)

    while chain:
        pivot = max_index(chain, simplex_dictionary)
        stored = T[pivot]
        if stored is None:
            # nothing stored for this pivot yet: the chain is fully reduced
            return chain
        # adding the stored chain (mod 2) cancels the shared pivot
        chain = chain ^ stored

    return chain

In [9]:
def max_index(boundary_chain, simplex_dictionary):
    '''
    Returns the largest value found at position 1 of the simplex_dictionary
    entries of the faces in boundary_chain

    The result is never below 0: an empty chain yields 0.
    '''
    indices = [simplex_dictionary[face][1] for face in boundary_chain]
    # the leading 0 reproduces the starting value of the running maximum
    return max([0, *indices])

In [10]:
def display_intervals(persistence_intervals):
    '''
    Creates barcode display for given persistence homology intervals for a single dimension

    Parameters:
        persistence_intervals: sequence of (birth, death) pairs; death may be
            float('inf') for classes that never die

    Unbounded intervals are drawn up to a finite cap just right of the largest
    finite endpoint, since a bar of infinite width cannot be rendered.
    Shows the figure with plt.show() and returns nothing.
    '''
    num_intervals = len(persistence_intervals)

    births = np.array([a for (a, _) in persistence_intervals], dtype=float)
    deaths = np.array([b for (_, b) in persistence_intervals], dtype=float)

    unbounded = ~np.isfinite(deaths)
    if unbounded.any():
        finite_ends = np.concatenate([births, deaths[~unbounded]])
        lo, hi = finite_ends.min(), finite_ends.max()
        # pad by 10% of the finite range; fall back to 1.0 when it is empty
        pad = 0.1 * (hi - lo) if hi > lo else 1.0
        deaths = np.where(unbounded, hi + pad, deaths)

    # keep the height of at least one bar so that an empty interval list does
    # not request a zero-height figure
    _, ax = plt.subplots(figsize=(6, max(num_intervals, 1) * 0.2))

    ax.barh(np.arange(num_intervals), deaths - births, left=births, height=0.7, color='black')
    ax.get_yaxis().set_visible(False)

    plt.show()