In [None]:
# To reload modified python modules
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np

from ripser import ripser
import functions.rips_blocks as rips_blocks

from scipy.spatial.distance import squareform, cdist

import matplotlib.pyplot as plt

from time import time
from datetime import timedelta

import sys

# Functions

In [None]:
def display_time(seconds):
    """Format a duration given in seconds as ``[D day(s), ]H:MM:SS.cc``.

    Parameters
    ----------
    seconds : int or float
        Duration in seconds (Python or NumPy scalar).

    Returns
    -------
    str
        ``str(timedelta(seconds=seconds))`` truncated (not rounded) to two
        decimal places, e.g. ``'0:01:05.25'``.
    """
    text = str(timedelta(seconds=float(seconds)))

    # str(timedelta) drops the fractional part whenever the duration has zero
    # microseconds. That happens for whole seconds, but also for values that
    # round to a whole second at microsecond precision (e.g. 5.0000001), so
    # we check the rendered string rather than the input value.
    if '.' not in text:
        text += '.000000'

    # Keep two of the six fractional digits
    return text[:-4]

In [None]:
# Computes the metric wedge of two distance matrices by gluing (identifying)
# point i1 of dm1 with point i2 of dm2
def metric_wedge(dm1, dm2, i1=0, i2=0, include_wedge=False):
    """Distance matrix of the metric wedge of two spaces.

    Point ``i1`` of ``dm1`` is identified with point ``i2`` of ``dm2``. The
    distance between a point ``i`` of ``dm1`` and a point ``j`` of ``dm2`` is
    ``dm1[i1, i] + dm2[i2, j]``, i.e. the length of the path through the
    wedge point.

    Parameters
    ----------
    dm1, dm2 : np.ndarray
        Square distance matrices. Row ``i1`` of ``dm1`` is used as the
        distances to the wedge point, so ``dm1`` is assumed to be symmetric.
    i1, i2 : int
        Indices of the points to identify. Negative indices count from the
        end, as in regular Python indexing.
    include_wedge : bool
        If False, the wedge point itself is removed from the result.

    Returns
    -------
    np.ndarray
        Square matrix with ``n1 + n2 - 1`` points if ``include_wedge`` is
        True and ``n1 + n2 - 2`` otherwise. Points of ``dm1`` come first,
        followed by the points of ``dm2``.

    Raises
    ------
    IndexError
        If ``i1`` or ``i2`` is out of range.
    """
    n1 = dm1.shape[0]
    n2 = dm2.shape[0]

    # Normalize the indices to [0, n). Without this, a negative index would
    # select the right row below but delete the wrong row/column afterwards.
    i1 = range(n1)[i1]
    i2 = range(n2)[i2]

    to_wedge_1 = dm1[i1, :].reshape(-1, 1)
    to_wedge_2 = dm2[i2, :].reshape(1, -1)

    # The distance between points i in dm1 and j in dm2
    # is dm1[i,i1] + dm2[i2,j]
    dists_inter = to_wedge_1 + to_wedge_2

    # Paste blocks into new distance matrix
    dm_wedge = np.block([[dm1, dists_inter], [dists_inter.T, dm2]])

    # The copy of the wedge point coming from dm2 is always a duplicate.
    # The copy coming from dm1 is removed only if the wedge point isn't wanted.
    # Since i1 < n1 <= n1 + i2, both can be deleted in a single pass.
    drop = [n1 + i2]
    if not include_wedge:
        drop.append(i1)

    dm_wedge = np.delete(dm_wedge, drop, axis=0)
    dm_wedge = np.delete(dm_wedge, drop, axis=1)

    return dm_wedge

# Random blocks

In [None]:
max_blocks = 15
block_sz = 20

seed = 304
rng = np.random.default_rng(seed)

# Points are sampled on the unit 3-sphere, i.e. as unit vectors in R^(dim+1)
dim = 3

Xs_all = []
dms_all = []
for n_blocks in range(1, max_blocks+1):
    # Point clouds and (rescaled) distance matrices of the individual blocks
    X_blocks = []
    dms_blocks = []
    for i in range(n_blocks):
        # Draw block_sz+2 points and project them onto the sphere. The two
        # extra points make up for the ones lost when this block is glued
        # to the previous ones and the wedge point is deleted.
        X0 = rng.random((block_sz+2, dim+1))
        X0 = X0 / np.linalg.norm(X0, axis=1, keepdims=True)

        # The first block is not glued to anything before it, so it
        # doesn't need the extra points
        if i == 0:
            X0 = X0[:-2, :]

        # Rescale block i by (i+1)^2 to increase the range of distances,
        # then round to 2 decimals
        scale = (i+1)**2
        dm = np.round(scale*cdist(X0, X0), 2)

        X_blocks.append(X0)
        dms_blocks.append(dm)

    # Paste the blocks one at a time: take the metric wedge of the matrix
    # built so far (at its last point) with the next block (at its first
    # point), and drop the wedge point. The result has a block structure
    # without being a metric wedge.
    dm_big = dms_blocks[0]
    for dm in dms_blocks[1:]:
        last = dm_big.shape[0] - 1
        dm_big = metric_wedge(dm_big, dm, i1=last, include_wedge=False)

    N = dm_big.shape[0]
    print(f'{n_blocks}/{max_blocks} Num. points: {N}')

    Xs_all.append(X_blocks)
    dms_all.append(dm_big)

In [None]:
# Heatmap of each block distance matrix left over from the last iteration
# of the generation loop (n_blocks and dms_blocks are its loop variables)
for block_dm in dms_blocks[:n_blocks]:
    fig, ax = plt.subplots(1, 1, figsize=(3, 3))
    heatmap = ax.imshow(block_dm)
    plt.colorbar(heatmap, ax=ax)

In [None]:
# Benchmark: for every generated distance matrix and every homology
# dimension k = 1..maxdim, time the block-based pipeline against plain
# Ripser and record whether both produce the same persistence diagrams.
# Row k-1 of each result array is dimension k; column i is dms_all[i],
# i.e. the matrix built from i+1 blocks.
maxdim = 3
time_blocks = np.zeros((maxdim, max_blocks))
time_ripser = np.zeros((maxdim, max_blocks))
same_diagrams = np.zeros((maxdim, max_blocks), dtype=bool)

for i in range(max_blocks):
    dm = dms_all[i]
    print(' -------------------------------------- ')
    print(f' n_points: {dm.shape[0]} ({i+1} blocks)')
    print(' -------------------------------------- ')

    for k in range(1, maxdim+1):
        print(f' ---- dim: {k} ---- ')

        # First decompose into blocks, then compute homology
        time_start_blocks = time()
        diagrams_blocks = rips_blocks.ripser_with_blocks(dm, file_name='dm_rand', maxdim=k)
        time_end_blocks = time()

        time_blocks[k-1, i] = time_end_blocks - time_start_blocks
        print('Time with blocks:', display_time(time_end_blocks-time_start_blocks))
        sys.stdout.flush()

        # Compute homology directly
        time_start_default = time()
        diagrams_full = ripser(dm, maxdim=k, distance_matrix=True)['dgms']
        time_end_default = time()

        time_ripser[k-1, i] = time_end_default - time_start_default
        print('Time whole:', display_time(time_end_default-time_start_default))
        sys.stdout.flush()

        # Sort Ripser diagrams so they can be compared entry by entry
        # (presumably ripser_with_blocks already returns them in the order
        # produced by sort_diagrams -- verify in rips_blocks)
        for idx in range(k+1):
            diagrams_full[idx] = rips_blocks.sort_diagrams(diagrams_full[idx])

        # Verify that we got the same diagrams in every dimension 0..k.
        # A generator is used so that no loop variable named `dim` leaks
        # out and overwrites the `dim` used when generating the data.
        check = all(np.array_equal(diagrams_blocks[d], diagrams_full[d])
                    for d in range(k+1))

        # `i` is 0-based, so the column is `i` (same as in the timing
        # arrays). Using `i-1` shifted every result by one column and put
        # the result for the first matrix in the last column.
        same_diagrams[k-1, i] = check
        print('Same diagrams:', check)
        print()

In [None]:
# Save time results
# Everything needed to reproduce the plots: the parameters of the experiment
# and the measured results. The raw point clouds (Xs_all) and distance
# matrices (dms_all) are not stored.
benchmark_results = {
    # Parameters
    'maxdim': maxdim,
    'max_blocks': max_blocks,
    'block_sz': block_sz,
    'seed': seed,
    # Results
    'time_blocks': time_blocks,
    'time_ripser': time_ripser,
    'same_diagrams': same_diagrams,
}
np.savez('benchmark_time.npz', **benchmark_results)

In [None]:
# Load results
# np.load returns a lazily-read NpzFile that keeps the file open, so we read
# everything inside a `with` block to make sure the handle gets closed.
with np.load('benchmark_time.npz') as data:
    # Scalars come back as 0-d arrays; convert them to plain ints so they
    # can be used wherever a Python integer is required
    maxdim = int(data['maxdim'])
    max_blocks = int(data['max_blocks'])
    time_blocks = data['time_blocks']
    time_ripser = data['time_ripser']
    same_diagrams = data['same_diagrams']

# Display global results (the two partial totals are in raw seconds)
print('Total time (blocks):', np.sum(time_blocks))
print('Total time (Ripser):', np.sum(time_ripser))

total_time = np.sum(time_blocks) + np.sum(time_ripser)
print('Total time -- all:', display_time(total_time))
print('All diagrams agree:', np.all(same_diagrams))

In [None]:
# Change font size. rcParams are read when the text objects are created, so
# this has to happen *before* the figure is built; otherwise it only takes
# effect the next time the cell is run.
plt.rcParams.update({'font.size': 14})

# One panel per homology dimension. maxdim may be a 0-d array if it was
# loaded from the .npz file, hence the int().
n_dims = int(maxdim)
fig, axes = plt.subplots(1, n_dims, figsize=(6*n_dims, 5),
                         sharex=True, sharey=True, squeeze=False)
axes = axes[0]   # squeeze=False always returns a 2-D array, even if n_dims == 1
block_range = np.arange(1, max_blocks+1)

for k in range(n_dims):
    axes[k].plot(block_range, time_blocks[k,:], label='Blocks', marker='.', markersize=10)
    axes[k].plot(block_range, time_ripser[k,:], label='Ripser', marker='.', markersize=10)
    axes[k].set_title(f'Dim {k+1}')
    axes[k].legend()
    axes[k].set_xlabel('Number of blocks')
    axes[k].set_ylabel('Time (s)')
    axes[k].set_yscale('log')
    axes[k].grid()

# Set x-ticks (the x axis is shared, so setting them once is enough)
xticks = np.arange(0, max_blocks+1, 3)
axes[0].set_xticks(xticks)

# Save figure
# NOTE(review): absolute, machine-specific path -- the cell fails on any
# other machine. Consider a configurable/relative output directory.
latex_dir = '/home/mrgomez/Documents/OSU/My-Papers/Split-Decompositions/figures'
plt.savefig(f'{latex_dir}/BloDec_vs_Ripser.pdf', dpi=300, bbox_inches='tight')

In [None]:
# Quick look at the block-based timings: one curve per row of time_blocks,
# on a logarithmic time axis
ax = plt.gca()
ax.plot(time_blocks.T)
ax.set_yscale('log')
ax.legend([1,2,3])