# Applications - Fast cohomological cycle matching

In this notebook, we reproduce the examples presented in Section 3 of the paper ["Fast Topological Signal Identification and Persistent Cohomological Cycle Matching" (García-Redondo, Monod, and Song 2022)](https://arxiv.org/abs/2209.15446), for those cases where the data is publicly available.

This notebook is intended to be used alongside a high-performance computing (HPC) cluster and is structured as follows:
- First we **generate the relevant data**, for which we use the auxiliary script `choose_data.py`. In this step we select the type of application:
    - `matching` for direct application of cycle matching, such as the tracking applications,
    - `prevalence` for prevalence application.
- Next, a marker in the notebook indicates when to **send the jobs to the HPC**. This step requires:
    1. Scripts to submit to the HPC. The repository provides examples for two workload managers, SLURM and OpenPBS. For OpenPBS, `submit_prevalence.pbs` and `submit_matching.pbs` run the prevalence and tracking applications, respectively. For SLURM, the corresponding scripts are `exec_prevalence.sh` and `exec_matching.sh`.
    2. The auxiliary Python scripts `apply_matching.py` and `apply_prevalent.py`, which are called from the workload manager scripts to run the matching and prevalence applications, respectively.
    3. A compiled version of the modified C++ files in the folders `ripser-tight-representative-cycles` and `simple-ripser-image`. These are due to [1] and [2], respectively, and the versions included here are only slightly altered to retrieve the indices of the simplices associated with the persistence pairs. These indices are obtained after taking a lexicographic refinement of the Vietoris-Rips filtration, as explained in [1].
- Finally, we **retrieve and process the data** from the previous computations.
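
To make the refinement mentioned in item 3 concrete: a Vietoris-Rips filtration assigns the same value to many simplices, so indexing simplices requires a tie-breaking rule that turns the filtration into a total order. The sketch below is only an illustration. It sorts by diameter, then dimension, then vertex tuple, which is an assumed tie-break and not necessarily the one implemented in the C++ code; the precise rule is described in [1]. The function name `refined_simplex_order` is hypothetical.

```python
from itertools import combinations
from math import dist


def refined_simplex_order(points, max_dim=2):
    """Totally order the simplices of a Vietoris-Rips complex.

    Sort key (an illustrative tie-break, not necessarily Ripser's):
    diameter, then dimension, then the tuple of vertex indices.
    """
    simplices = []
    for dim in range(max_dim + 1):
        for vertices in combinations(range(len(points)), dim + 1):
            # Diameter = largest pairwise distance; 0 for a single vertex.
            diameter = max(
                (dist(points[a], points[b]) for a, b in combinations(vertices, 2)),
                default=0.0,
            )
            simplices.append((diameter, dim, vertices))
    return sorted(simplices)


# Three points forming a right triangle.
points = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0)]
order = refined_simplex_order(points)
for index, (diameter, dim, vertices) in enumerate(order):
    print(index, vertices, round(diameter, 3))
```

The position of a simplex in `order` plays the role of its index: two simplices with equal diameter still receive distinct, reproducible positions.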

### References
[1] Bauer, Ulrich. 2021. ‘Ripser: Efficient Computation of Vietoris-Rips Persistence Barcodes’. Journal of Applied and Computational Topology 5 (3): 391–423. https://doi.org/10.1007/s41468-021-00071-5.

[2] Bauer, Ulrich, and Maximilian Schmahl. 2022. ‘Efficient Computation of Image Persistence’. ArXiv:2201.04170 [Cs, Math], January. http://arxiv.org/abs/2201.04170.


## Matching - tracking applications

### Choose the type of application

In [None]:
application = 'matching' #'prevalence'

### Generate the data

In [None]:
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np

from utils_PH import *
from utils_plot import *

plt.rcParams['savefig.facecolor'] = 'white' # set background to white in savefig

if application == 'prevalence': 
    from choose_data import DATASET, generate_data_resamplings, N_ref, N, N_resamp, noise_scale

    # temp folder
    temp_folder = '{}_temp/'.format(DATASET)
    if not os.path.exists(temp_folder) :
        os.mkdir(temp_folder)
        print('created temp folder for dataset')
    else :
        raise FileExistsError('{} already exists: rename or remove it before generating new data (do not disable this check)'.format(temp_folder))

    # generate data + resamplings
    full_data, X, list_Y = generate_data_resamplings(dataset = DATASET)
    print('dataset :', DATASET)
    print('params : ', N_ref, N, N_resamp, noise_scale)

    if os.path.exists('{}_temp/X_{}.pkl'.format(DATASET,N_ref)) :
        raise Exception('data already generated')

    # save data
    with open('{}_temp/full_data.pkl'.format(DATASET), 'wb') as f:
        pickle.dump(full_data, f)
    with open('{}_temp/X_{}.pkl'.format(DATASET,N_ref), 'wb') as f:
        pickle.dump(X, f)
    with open('{}_temp/list_Y_samp{}_{}.pkl'.format(DATASET, N_resamp, N), 'wb') as f:
        pickle.dump(list_Y, f)

    print('finished generating data and resamplings')

    # compute PH of Xref
    out_X = compute_bars_tightreps(X,filename = '{}_temp/ldm_X'.format(DATASET))
    bars_X, reps_X, tight_reps_X, indices_X = extract_bars_reps_indices(out_X, only_dim_1 = True)    

    res_X = bars_X, reps_X, tight_reps_X, indices_X
    # save it
    with open('{}_temp/res_X_{}.pkl'.format(DATASET,N_ref), 'wb') as f:
        pickle.dump(res_X, f)

    # remove unused file
    os.remove('{}_temp/ldm_X.lower_distance_matrix'.format(DATASET))

    print('finished computing PH of ref data')

if application == 'matching':
    from choose_data import DATASET, generate_data_matching, N, noise_scale
    
    # temp folder
    temp_folder = '{}_temp/'.format(DATASET)
    if not os.path.exists(temp_folder) :
        os.mkdir(temp_folder)
        print('created temp folder for dataset')
    else :
        raise FileExistsError('{} already exists: rename or remove it before generating new data (do not disable this check)'.format(temp_folder))
        
    # generate data
    full_data, list_X, list_indices = generate_data_matching(dataset = DATASET)
    print('dataset :', DATASET)
    print('number of points of the samples : ', N)

    # check the same file name that is written below
    if os.path.exists('{}_temp/list_X.pkl'.format(DATASET)) :
        raise Exception('data already generated')

    # save data
    with open('{}_temp/full_data.pkl'.format(DATASET), 'wb') as f:
        pickle.dump(full_data, f)
    with open('{}_temp/list_X.pkl'.format(DATASET), 'wb') as f:
        pickle.dump(list_X, f)

    print('finished generating data')

### ------ Run the jobs in the HPC now ------
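
As a rough guide to this step, the sketch below builds the submission command for each workload manager from the script names listed at the top of this notebook. `sbatch` and `qsub` are the standard submission commands of SLURM and OpenPBS. The helper `submission_command` is hypothetical, and the sketch assumes the job scripts can be submitted without extra arguments; they may first need editing for the cluster at hand (paths, resources, number of jobs).

```python
import shlex

# Script names as listed at the top of this notebook.
SUBMIT_SCRIPTS = {
    ("slurm", "matching"): "exec_matching.sh",
    ("slurm", "prevalence"): "exec_prevalence.sh",
    ("openpbs", "matching"): "submit_matching.pbs",
    ("openpbs", "prevalence"): "submit_prevalence.pbs",
}
SUBMIT_PROGRAMS = {"slurm": "sbatch", "openpbs": "qsub"}


def submission_command(application, manager):
    """Return the argument list that submits the job script (hypothetical helper)."""
    try:
        script = SUBMIT_SCRIPTS[(manager, application)]
    except KeyError:
        raise ValueError(f"unknown combination: {manager!r}, {application!r}") from None
    return [SUBMIT_PROGRAMS[manager], script]


command = submission_command("matching", "slurm")
print(shlex.join(command))  # paste this line into a shell on the cluster
# To submit from Python instead: subprocess.run(command, check=True)
```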

### Collect data and visualise

#### Prevalence

In [None]:
## COLLECT RESULTS - prevalence

with open('{}_temp/res_X_{}.pkl'.format(DATASET,N_ref), 'rb') as f:
    res_X = pickle.load(f)
bars_X, reps_X, tight_reps_X, indices_X = res_X
bars_reps_X = bars_X, reps_X, tight_reps_X

list_matched_X_Y = []
list_affinity_X_Y = []
list_bars_reps_Y = []

for y in range(N_resamp) :

    result_i = pickle.load(open('{}_temp/res_match_{}.pkl'.format(DATASET, y), 'rb'))
    matched_X_Y, affinity_X_Y, bars_reps_Y = result_i

    bars_Y, reps_Y, tight_reps_Y = bars_reps_Y

    list_matched_X_Y += [matched_X_Y]
    list_affinity_X_Y += [affinity_X_Y]
    list_bars_reps_Y += [bars_reps_Y]
    
# compute prevalence scores for all bars of Xref

dim = 1
PH1 = bars_X[dim]
PH1 = np.array(PH1)
ph1 = len(PH1)

scores = np.zeros(ph1)

for s in range(N_resamp) :
    for match, aff in zip(list_matched_X_Y[s],list_affinity_X_Y[s]) :
        a,b = match
        scores[a] += aff
scores /= N_resamp

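# sort PH1 and scores by birth time; do this only AFTER computing the scores above,
# because the matches index the bars in their original (unsorted) order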
arg = argsort(PH1[:,0], option = 'asc')
PH1_sorted = PH1[arg]
scores_sorted = scores[arg]
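
The score accumulated in the cell above can be read as follows: for each bar of the reference sample, sum the affinities of its matches over all resamplings and divide by the number of resamplings, so that an unmatched bar contributes 0. The sketch below reproduces that computation on invented toy values, in plain Python; the function name `prevalence_scores` and the toy data are assumptions for illustration only.

```python
def prevalence_scores(n_bars, matches_per_resampling, affinities_per_resampling):
    """Average, over resamplings, the affinity of the match found for each reference bar.

    A bar left unmatched in a resampling contributes 0 for that resampling.
    """
    n_resamplings = len(matches_per_resampling)
    scores = [0.0] * n_bars
    for matches, affinities in zip(matches_per_resampling, affinities_per_resampling):
        for (bar_in_X, _bar_in_Y), affinity in zip(matches, affinities):
            scores[bar_in_X] += affinity
    return [total / n_resamplings for total in scores]


# Toy input: 3 reference bars, 2 resamplings (values invented for illustration).
toy_matches = [[(0, 1), (2, 0)], [(0, 0)]]
toy_affinities = [[0.75, 0.5], [0.25]]
toy_scores = prevalence_scores(3, toy_matches, toy_affinities)
print(toy_scores)  # [0.5, 0.0, 0.25]
```

Bar 0 is matched in both resamplings, bar 2 in only one, and bar 1 in none, which is why its score is zero.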

In [None]:
# plot AUGMENTED BARCODES of X : with thickness / color = prevalence, length = persistence

figpath1 = '{}_temp/{}_barcode_Nref{}_samp{}_N{}.png'.format(DATASET,DATASET, N_ref, N_resamp, N)
figpath2 = '{}_temp/{}_augmented_barcode_Nref{}_samp{}_N{}.png'.format(DATASET,DATASET, N_ref, N_resamp, N)

plot_bars_PH1(PH1_sorted, scores = None, diagonal = False, delta_y = .2, delta_y_prev = 1, figpath = figpath1)
plot_bars_PH1(PH1_sorted, scores = scores_sorted, diagonal = False, delta_y = .2, delta_y_prev = 1, figpath = figpath2)

# scatter plot of prevalence scores
plt.figure()
plt.scatter(np.arange(ph1), scores_sorted)
plt.xticks([]) #plt.xticks(np.arange(len(scores)))
plt.title('Prevalence scores of cycles in X (sorted by birth time)')
plt.tight_layout()
plt.savefig('{}_temp/{}_scores_Nref{}_samp{}_N{}.png'.format(DATASET,DATASET, N_ref, N_resamp, N), dpi = 300)
plt.show()

In [None]:
# show cycreps on X
XT = X[:,[1,0]]          # swap the two coordinates: (x, y) -> (y, x)
XTt = np.zeros(X.shape)  # rotate by 90 degrees clockwise: (x, y) -> (y, -x)
XTt[:,0] = X[:,1]
XTt[:,1] = -X[:,0]

ax = plot_cycreps(XTt, tight_reps_X[dim], ax = None, zoom_factor = 1.5, return_ax = True)
plt.axis('off')
plt.show()

# show stained cycreps on X
ax = plot_cycreps_prevalence_2D(XTt, tight_reps_X[dim], scores, plot_points = True,
                           zoom_factor = 1.5, maxi_colorbar = None, return_ax = True)
plt.axis('off')
plt.tight_layout()
plt.savefig('{}_temp/{}_stained_cyclesX_Nref{}_samp{}_N{}.png'.format(DATASET, DATASET,N_ref, N_resamp, N), dpi = 300)
plt.show()

In [None]:
if 'actin' in DATASET :
    
    # plot prevalent cycles + img + pts ?

    ax = plot_cycreps_prevalence_2D(XT, tight_reps_X[dim], scores, plot_points = False,
                                    maxi_colorbar = None, zoom_factor = 2, return_ax = True)

    ## merged resampling point clouds
    #xx, yy = whole_Y.T
    #ax.scatter(xx, yy, c = '#729dcf', s = 5, alpha = 1)

    ##  original img
    img, otsu = full_data
    ax.imshow(img, cmap = 'gray')
    plt.tight_layout()
    ## save the figure (comment out the line below to skip)
    plt.savefig('{}_temp/{}_result_Nref{}_samp{}_N{}.png'.format(DATASET,DATASET, N_ref, N_resamp, N), dpi = 300)

    plt.show()

In [None]:
if DATASET == 'cosmic_web' :
    # plot prevalent cycles + galaxies (optional : + filaments (Duque et al., 2022))

    ax = plot_cycreps_prevalence_2D(X, tight_reps_X[dim], scores, plot_points = False,
                                    maxi_colorbar = None, zoom_factor = 2, return_ax = True)

    ## merged resampling point clouds
    #xx, yy = whole_Y.T
    #ax.scatter(xx, yy, c = '#729dcf', s = 5, alpha = 1)

    ##  original galaxy distribution
    x,y = full_data.T
    ax.scatter(x, y, c = '#729dcf', s = 5, alpha = 1)

    ## filaments (load them with previous section)
    # ax.scatter(xf,yf,s = 5, c = 'green',alpha = .2)

    ## uncomment below to save
    #plt.savefig('{}_temp/{}_result_Nref{}_samp{}_N{}.png'.format(DATASET,DATASET,N_ref, N_resamp, N), dpi = 300)

    plt.show()

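When satisfied with the results, **rename the temp folder**. The data generation cell refuses to run while `{DATASET}_temp/` exists, so renaming the folder both preserves these results and allows a new run.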
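
One way to do the renaming from Python is sketched below. The helper `archive_temp_folder` and its timestamp default are assumptions made for this example, not part of the repository; it refuses to overwrite an existing target, in the same spirit as the check in the data generation cell.

```python
import os
import tempfile
from datetime import datetime


def archive_temp_folder(dataset, root=".", label=None):
    """Rename '<dataset>_temp' to '<dataset>_<label>' and return the new path.

    Hypothetical helper: the label defaults to a timestamp.
    """
    source = os.path.join(root, f"{dataset}_temp")
    if label is None:
        label = datetime.now().strftime("%Y%m%d_%H%M%S")
    target = os.path.join(root, f"{dataset}_{label}")
    if not os.path.isdir(source):
        raise FileNotFoundError(source)
    if os.path.exists(target):
        raise FileExistsError(target)
    os.rename(source, target)
    return target


# Demonstration in a throwaway directory, with a made-up dataset name.
with tempfile.TemporaryDirectory() as root:
    os.mkdir(os.path.join(root, "toy_temp"))
    archived = archive_temp_folder("toy", root=root, label="run1")
    temp_still_exists = os.path.exists(os.path.join(root, "toy_temp"))
    print(os.path.basename(archived), temp_still_exists)
```

In the notebook itself the call would be `archive_temp_folder(DATASET)`, run from the directory that contains the temp folder.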

#### Matching - tracking

In [None]:
## COLLECT RESULTS - matching

from utils_PH import *
from utils_plot import *
plt.rcParams['savefig.facecolor']='white' # set background to white in savefig
import pickle


from choose_data import DATASET, generate_data_matching, N, noise_scale

with open('{}_temp/list_X.pkl'.format(DATASET), 'rb') as f:
    list_X = pickle.load(f)
with open('{}_temp/full_data.pkl'.format(DATASET), 'rb') as f:
    full_data = pickle.load(f)
u, thres_mean = full_data

list_matched_X_Y = []
list_affinity_X_Y = []
list_bars_reps_X = []
list_bars_reps_Y = []

for y in range(len(list_X)-1) :
    with open('{}_temp/res_match_{}.pkl'.format(DATASET, y), 'rb') as f:
        result_i = pickle.load(f)
    
    matched_X_Y, affinity_X_Y, bars_reps_X, bars_reps_Y = result_i

    list_matched_X_Y += [matched_X_Y]
    list_affinity_X_Y += [affinity_X_Y]
    list_bars_reps_X += [bars_reps_X]
    list_bars_reps_Y += [bars_reps_Y]

In [None]:
# track all the cycles in the different slices
from utils_PH import *

initial_slice = 0 #always for our applications
N_slices = 10
step = 1

# for lateral line: N_slices = , step = 3
# for heartbeat and embryogenesis: N_slices = 10, step = 1

cycles_tracked = {}
affinities_tracked = {}
list_indices =[i for i in range(step, N_slices + initial_slice -  step, step)]

from copy import deepcopy
list_matched_copy = deepcopy(list_matched_X_Y)

for i in range(initial_slice, N_slices + initial_slice - step, step):
    cycles_tracked[i] = []
    affinities_tracked[i] = []
    if i > initial_slice:
        list_indices.remove(i)
    # iterate over a snapshot: tracked matches are removed from list_matched_copy below,
    # and removing from a list while iterating over it skips elements
    for match in list(list_matched_copy[i]):
        cycle = match[0]
        tracked_cycle, tracked_affinity = \
            track_cycles_from_slice(list_matched_X_Y, list_affinity_X_Y, cycle, list_indices, initial_slice = i)
        length_track = len(tracked_cycle)
        if length_track > 1 :
            cycles_tracked[i] += [tracked_cycle]
            affinities_tracked[i] += [tracked_affinity]
            # remove the matches of this track so that it is not recorded again from a later slice
            for k, t in zip(range(i, N_slices + initial_slice -  step, step),  range(length_track)):
                list_matched_copy[k].remove(tracked_cycle[t])
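
The idea behind the tracking step can be illustrated independently of `utils_PH`: a cycle is followed from slice to slice by chaining matches, where a match `(a, b)` in slice `s` sends cycle `a` of slice `s` to cycle `b` of slice `s + 1`. The sketch below is a simplified stand-in written for this explanation, not the implementation of `track_cycles_from_slice`; it assumes `step = 1`, at most one match per cycle, and it ignores affinities.

```python
def follow_cycle(matches_per_slice, start_slice, cycle):
    """Follow one cycle through consecutive slices.

    matches_per_slice[s] lists pairs (a, b): cycle a of slice s is matched
    to cycle b of slice s + 1. Returns the chain of matches that was followed.
    """
    chain = []
    current = cycle
    for s in range(start_slice, len(matches_per_slice)):
        successors = [match for match in matches_per_slice[s] if match[0] == current]
        if not successors:
            break  # the cycle has no match in the next slice: the track ends
        chain.append(successors[0])
        current = successors[0][1]
    return chain


# Toy matches between 4 consecutive slices (values invented for illustration).
toy_slice_matches = [[(0, 2), (1, 0)], [(2, 1)], [(1, 3), (0, 0)]]
print(follow_cycle(toy_slice_matches, 0, 0))  # [(0, 2), (2, 1), (1, 3)]
print(follow_cycle(toy_slice_matches, 0, 1))  # [(1, 0)]
```

The first cycle survives through all four slices, while the second is lost after one step, which corresponds to a track of length 1 that the cell above discards.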

In [None]:
# Make a single list with all the information for the plots
list_bars = {}
list_reps = {}
list_tight_reps = {}

for i in range(initial_slice, initial_slice + N_slices  -step, step):
    list_bars[i], list_reps[i], list_tight_reps[i] = list_bars_reps_X[i]

out_last_slice = compute_bars_tightreps(list_X[N_slices - 1])
list_bars_reps_last_slice = extract_bars_reps(out_last_slice)
list_bars[N_slices - 1], list_reps[N_slices - 1], list_tight_reps[N_slices - 1] = list_bars_reps_last_slice

In [None]:
# Count the number of tunnels and generate colors accordingly
n = 0
for i in range(initial_slice, initial_slice + N_slices-step,step):
    n += len(cycles_tracked[i])

import random
no_of_colors= n
colors =["#"+''.join([random.choice('0123456789ABCDEF') for i in range(6)]) 
       for j in range(no_of_colors)]
for j in range(no_of_colors):
    plt.scatter(random.randint(0,10),random.randint(0,10),c=colors[j],s=200)
plt.show()
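
The colours generated above change on every run, so the same tunnel may be drawn in a different colour each time the notebook is executed. If reproducible figures are wanted, one option is to draw the colours from a seeded generator, as in the sketch below. The function name `random_hex_colors` and the seed value are arbitrary choices for this example.

```python
import random


def random_hex_colors(n_colors, seed=0):
    """Return n_colors strings such as '#3FA2C1', reproducible for a given seed."""
    rng = random.Random(seed)  # private generator: leaves the global random state untouched
    return ["#" + "".join(rng.choice("0123456789ABCDEF") for _ in range(6))
            for _ in range(n_colors)]


palette = random_hex_colors(5, seed=42)
print(palette)
```

Calling `random_hex_colors(n, seed=42)` in place of the list comprehension above would keep the colour of each tracked cycle stable across runs.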

In [None]:
#### PLOT FOR THE LATERAL LINE (FOR INSTANCE)

tracked = cycles_tracked[0][0]
length_track = len(tracked)
u, thres_otsu = full_data
images = u
list_cycreps = list_tight_reps
dim = 1


# some pre-selected colors
#colors = ['indianred', 'maroon', 'salmon', 'orangered', 'sandybrown', 'bisque', 'goldenrod', 'gold', 'yellow', 'olivedrab',
#        'greenyellow', 'darkseagreen', 'palegreen', 'forestgreen', 'turquoise', 'teal', 'slategray', 'lightblue', 'steelblue', 
#        'blue', 'mediumslateblue', 'blueviolet', 'violet', 'purple', 'magenta', 'hotpink', 'crimson', 'palevioletred','skyblue', 
#        'blue', 'pink', 'red', 'green', 'yellow', 'cyan', 'olive' ]


fig, ax = plt.subplots(3, 5, figsize = (20,12))

row = -1
c = 0
for j in range(len(list_X)-1):
    #print(j)
    #print(j, j%5)
    if j % 5 == 0:
        row += 1
    ax[row,j%5].imshow(images[j*3, :, :], cmap = 'gray', origin = 'lower')
    ax[row,j%5].set_aspect('equal')
    num_cycles = len(cycles_tracked[j])
    for k in range(num_cycles):
        tracked = cycles_tracked[j][k]
        for i, match in enumerate(tracked):
            index_cycle = match[0]
            cycrep = list_cycreps[(j+i)][dim][index_cycle]
            xx = list_X[3*(j+i)][:,0]
            yy = list_X[3*(j+i)][:,1]
            if len(cycrep) > 0 :
                cycrep = np.array(cycrep)
                ax[((j+i)-(j+i)%5)//5,(j+i)%5].plot(yy[cycrep.T], xx[cycrep.T], c = colors[c], alpha = 0.7, linewidth = 2)
            if j + i == 13:
                # match[1] indexes the matched cycle in the next slice,
                # which is drawn in the final panel (index 14)
                index_cycle = match[1]
                cycrep = list_cycreps[(j+i+1)][dim][index_cycle]
                xx = list_X[3*(j+i+1)][:,0]
                yy = list_X[3*(j+i+1)][:,1]
                cycrep = np.array(cycrep)
                ax[((j+i)-(j+i)%5)//5,(j+i)%5+1].plot(yy[cycrep.T], xx[cycrep.T], c = colors[c], alpha = 0.7, linewidth = 2)
                    
        c += 1
ax[2,4].imshow(images[14*3, :, :], cmap = 'gray', origin = 'lower')

for i in range(3):
    for j in range(5):
        ax[i,j].tick_params(bottom = False, left = False, labelbottom = False, labelleft = False)

plt.tight_layout()
plt.show()
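
In the plotting cell above, the panel of a slice is located with the expressions `((j+i)-(j+i)%5)//5` and `(j+i)%5`, which are the row and column of a flat index in a row-major grid with 5 columns. The sketch below shows the same mapping written with the built-in `divmod`; the name `panel_position` is introduced here only for illustration.

```python
N_COLS = 5  # the figure above uses a 3 x 5 grid


def panel_position(flat_index, n_cols=N_COLS):
    """Map a flat panel index to (row, column) in a row-major grid."""
    return divmod(flat_index, n_cols)


for t in (0, 4, 5, 13, 14):
    print(t, panel_position(t))
```

With this helper, `ax[panel_position(j + i)]` would select the same axes as the longer index expressions.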