# Run persistent homology
Running this notebook will:
- Use Ripser to get Betti bar codes from saved rates. 
- If nCells > 10, reduce the smoothed rates to 10 dimensions using Isomap.
- Threshold out low density points if thresholded is True.

**Ripser note:** the required version of the ripser package can be installed on Ubuntu with:
```
pip install Cython
pip install ripser==0.3.2
```

## Setup

In [1]:
# General imports
import sys, os
import time, datetime
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from ripser import ripser as tda

# Seed NumPy's global RNG from the sub-second part of the clock and report the
# value, so that a run can be reproduced later by seeding with the same number.
sd = int((time.time() % 1) * (2**31))
np.random.seed(sd)
print('Random seed: %d' % sd)

# Date prefix of the form 'YYYY_MM_DD_'
curr_date = datetime.datetime.now().strftime('%Y_%m_%d') + '_'

# Make the shared scripts importable. The membership check keeps sys.path from
# accumulating duplicate entries when this cell is re-run in the same kernel.
gen_fn_dir = os.path.join(os.path.abspath('.'), 'shared_scripts')
if gen_fn_dir not in sys.path:
    sys.path.append(gen_fn_dir)

import general_file_fns as gff
from binned_spikes_class import spike_counts
from dim_red_fns import run_dim_red
from scipy.spatial.distance import pdist
from sklearn import neighbors

# Recording and brain state to analyse
session = 'Mouse28-140313'
state = 'Wake'
area = 'ADn'

# When True, low-density points are removed before persistent homology is run
thresholded = False

# Settings passed to spike_counts when building the rate matrix
dt_kernel = 0.1
sigma = 0.1
rate_params = {'dt': dt_kernel, 'sigma': sigma}

# Number of consecutive rate samples averaged into one smoothed sample
d_idx = 10

# Shared parameters for the project
gen_params = gff.load_pickle_file('./general_params/general_params.pkl')

# Directory that receives the TDA results
# (return_dir presumably creates it when missing -- confirm in general_file_fns)
save_dir = gff.return_dir(gen_params['results_dir'] + "/TDA")

print('Session: %s, state: %s' % (session, state))

Reading data from ./general_params/general_params.pkl...
Reading data from ./general_params/general_params.pkl...
Making ./data/analyses//TDA
Session: Mouse28-140313, state: Wake


### Load and smooth kernel rates

In [None]:
# Build the rate matrix for the chosen session / region and keep the first
# element returned by get_spike_matrix for the requested state.
session_rates = spike_counts(session, rate_params, count_type='rate', anat_region=area)
rates_all = session_rates.get_spike_matrix(state)[0]

# Number of columns of the rate matrix (one per cell, going by the name)
nCells_tot = rates_all.shape[1]

# Average non-overlapping windows of d_idx consecutive samples. Samples at the
# end that do not fill a complete window are left out.
n_smooth_samples = len(rates_all) // d_idx
smooth_rates = np.zeros((n_smooth_samples, nCells_tot))
for win, start in enumerate(range(0, n_smooth_samples * d_idx, d_idx)):
    window = rates_all[start:start + d_idx]
    smooth_rates[win] = np.mean(window, axis=0)

# Container for the persistence diagrams, filled in once they are computed
results = {'session': session, 'h0': [], 'h1': [], 'h2': []}

# If there are more cells than fit_dim, reduce to fit_dim dimensions using Isomap.
# The guard is tied to fit_dim so the two cannot drift apart when it is changed.
fit_dim = 10
dr_method = 'iso'
n_neighbors = 5
dim_red_params = {'n_neighbors': n_neighbors, 'target_dim': fit_dim}
if nCells_tot > fit_dim:
    rates = run_dim_red(smooth_rates, params=dim_red_params, method=dr_method)
else:
    rates = smooth_rates

# Threshold out outlier points with low neighborhood density
if thresholded:
    # a) Count, for every point, the neighbors lying within a radius equal to
    #    the 1st percentile of all pairwise Euclidean distances.
    dist = pdist(rates, 'euclidean')
    rad = np.percentile(dist, 1)
    neigh = neighbors.NearestNeighbors()
    neigh.fit(rates)
    nbr_indices = neigh.radius_neighbors(X=rates, radius=rad, return_distance=False)
    num_nbrs = np.array([len(nbrs) for nbrs in nbr_indices])

    # b) Keep only points whose neighbor count is strictly above the
    #    thresholded_prcnt-th percentile of all counts.
    thresholded_prcnt = 20
    threshold = np.percentile(num_nbrs, thresholded_prcnt)
    thresholded_rates = rates[num_nbrs > threshold]
    rates = thresholded_rates

# H0 & H1: computed on every point in `rates`
H1_rates = rates
barcodes = tda(H1_rates, maxdim=1, coeff=2)['dgms']
results['h0'] = barcodes[0]
results['h1'] = barcodes[1]

# H2. Need to subsample points for computational tractability if the
# number of points is large (can go higher but very slow). Note that H2 is
# therefore computed on a random subset, while H0/H1 above use all points.
max_h2_points = 1500
if len(rates) > max_h2_points:
    idx = np.random.choice(np.arange(len(rates)), max_h2_points, replace=False)
    H2_rates = rates[idx]
else:
    H2_rates = rates
barcodes = tda(H2_rates, maxdim=2, coeff=2)['dgms']
results['h2'] = barcodes[2]

# Save. Multiplying the suffix by the boolean yields '' when thresholded is False.
# os.path.join places the file inside save_dir whether or not save_dir ends
# with a path separator (plain concatenation would not).
barcode_fname = '%s_%s%s_ph_barcodes.p' % (session, state, ('_thresholded' * thresholded))
gff.save_pickle_file(results, os.path.join(save_dir, barcode_fname))

# If plotting from a saved file, uncomment this and replace with appropriate file.
# results = gff.load_pickle_file(gen_params['results_dir'] + '2019_03_22_tda/Mouse28-140313_Wake_ph_barcodes.p')

# Switch for the barcode figure; set to False to skip plotting.
plot_barcode = True

if plot_barcode:
    col_list = ['r', 'g', 'm', 'c']

    # Work on float copies shaped (n_bars, 2) so that capping the infinite bar
    # below does not modify `results`, and so that an empty diagram is still 2-D.
    h0, h1, h2 = (np.array(results[key], dtype=float).reshape(-1, 2)
                  for key in ('h0', 'h1', 'h2'))

    # Replace non-finite values in H0 (the bar that never dies) with a large
    # finite number so that the bar can be drawn.
    h0[~np.isfinite(h0)] = 100

    # Plot the longest barcodes only: keep bars whose length is strictly above
    # the given percentile of bar lengths in that dimension.
    plot_prcnt = [99, 98, 90]  # order is h0, h1, h2
    to_plot = []
    for curr_h, cutoff in zip([h0, h1, h2], plot_prcnt):
        bar_lens = curr_h[:, 1] - curr_h[:, 0]
        if len(bar_lens) == 0:
            # np.percentile cannot handle an empty array; nothing to draw here
            to_plot.append(curr_h)
            continue
        plot_h = curr_h[bar_lens > np.percentile(bar_lens, cutoff)]
        to_plot.append(plot_h)

    fig = plt.figure(figsize=(10, 8))
    gs = gridspec.GridSpec(3, 4)
    for curr_betti, curr_bar in enumerate(to_plot):
        # One full-width row per homology dimension
        ax = fig.add_subplot(gs[curr_betti, :])
        for i, interval in enumerate(reversed(curr_bar)):
            ax.plot([interval[0], interval[1]], [i, i], color=col_list[curr_betti],
                    lw=1.5)
        # ax.set_xlim([0, xlim])
        # ax.set_xticks([0, xlim])
        ax.set_ylim([-1, len(curr_bar)])
        ax.set_ylabel('H%d' % curr_betti)
        # ax.set_yticks([])
    ax.set_xlabel('Filtration value (birth to death)')
    plt.show()
