In [1]:
import numpy as np
import matplotlib.pyplot as plt
from ripser import ripser as tda
import matplotlib.gridspec as gridspec
from persim import plot_diagrams
import os


def _plot_persistence_diagram(diagrams, out_path):
    """Draw ``diagrams`` with persim's ``plot_diagrams`` and save the figure to ``out_path``."""
    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        plot_diagrams(diagrams, ax=ax)
        fig.savefig(out_path)
    finally:
        # Always release the figure, even if saving fails (e.g. the output drive is gone).
        plt.close(fig)


def _plot_barcodes(bars_by_dim, out_path, colors=('r', 'g', 'm')):
    """Save one barcode panel per homology dimension (top to bottom: H0, H1, ...) to ``out_path``."""
    fig, axes = plt.subplots(len(bars_by_dim), 1, figsize=(10, 8), squeeze=False)
    try:
        for dim, (ax, bars) in enumerate(zip(axes[:, 0], bars_by_dim)):
            if len(bars):
                # Rows are filled in reverse order so the first interval lands on the
                # top row, matching the previous per-bar loop over reversed(bars).
                rev = bars[::-1]
                ax.hlines(np.arange(len(rev)), rev[:, 0], rev[:, 1],
                          color=colors[dim], linewidth=1.5)
            ax.set_ylim([-1, len(bars)])
            ax.set_ylabel(f'H{dim}')
        axes[-1, 0].set_xlabel('filtration value')
        fig.savefig(out_path)
    finally:
        plt.close(fig)


def run_ripser(data, model_dir='', plot_barcode=True, inf_death=100):
    """Compute H0-H2 persistent homology of ``data`` and optionally save plots.

    Persistence is computed by ripser with ``maxdim=2`` and ``coeff=2``
    (homology up to dimension 2 over Z/2).

    Parameters
    ----------
    data : array-like
        Passed unchanged to ripser. NOTE(review): presumably a point cloud of
        shape (n_points, n_features); confirm against the caller.
    model_dir : str, optional
        Output path *prefix*, not a directory. When non-empty, the persistence
        diagram is saved to ``model_dir + '_persistent.png'`` and (if
        ``plot_barcode``) the barcode to ``model_dir + '.png'``. An empty
        string disables all plotting.
    plot_barcode : bool, optional
        Whether to save the barcode figure. Defaults to True, the previously
        hard-coded behaviour.
    inf_death : float, optional
        Death value substituted for infinite H0 bars so they can be drawn and
        measured. Defaults to 100, the previously hard-coded value.

    Returns
    -------
    to_plot : list of np.ndarray
        ``[h0, h1, h2]`` (birth, death) arrays; ``h0`` is a copy with infinite
        deaths replaced by ``inf_death``.
    bar_lens_list : list of np.ndarray
        Bar lengths (death - birth) for h0, h1 and h2, in that order.
    """
    diagrams = tda(data, maxdim=2, coeff=2)['dgms']
    save = model_dir != ''

    if save:
        # Plot the raw diagrams (with infinite H0 bars) before any substitution.
        _plot_persistence_diagram(diagrams[:3], model_dir + '_persistent.png')

    # Copy rather than overwrite ripser's output in place.
    h0 = np.array(diagrams[0], copy=True)
    h0[~np.isfinite(h0)] = inf_death
    to_plot = [h0, diagrams[1], diagrams[2]]
    bar_lens_list = [bars[:, 1] - bars[:, 0] for bars in to_plot]

    # Figures were never shown (they were closed right away), so skip them when not saving.
    if save and plot_barcode:
        _plot_barcodes(to_plot, model_dir + '.png')

    return to_plot, bar_lens_list

In [2]:
# Flattened stimulus grid of shape (n_loc, n_loc, 1). np.unravel_index uses C
# (row-major) order, so flat index i maps to stim1 = i // n_loc,
# stim2 = i % n_loc, and ind_repeat is always 0 for the size-1 trailing axis.
n_loc = 128
stim_loc_shape = (n_loc, n_loc, 1)
stim_loc_size = np.prod(stim_loc_shape)
ind_stim_loc1, ind_stim_loc2, ind_repeat = np.unravel_index(
    np.arange(stim_loc_size), stim_loc_shape
)

In [3]:
# For each trained model, compute persistent-homology bar lengths on 1-D slices
# of the isomap projection: fix one stimulus location and let the other vary.
n_rnn = 256
bs = 512
# Root of the saved training runs. Set the FYP_RESULTS_ROOT environment variable
# to point elsewhere instead of editing this machine-specific absolute path.
RESULTS_ROOT = os.environ.get('FYP_RESULTS_ROOT', '/Volumes/Seagate Backup Plus Drive/fyp/results')


def _save_subspace_bar_lens(proj, mask, out_dir):
    """Run run_ripser on the rows ``proj[mask]`` and save the bar lengths in ``out_dir``.

    Creates ``out_dir`` if needed. run_ripser writes ``betti_persistent.png`` and
    ``betti.png`` there; this function writes ``h0_lens.npy``, ``h1_lens.npy`` and
    ``h2_lens.npy``. Returns the three bar-length arrays.
    """
    os.makedirs(out_dir, exist_ok=True)
    _, bar_lens_list = run_ripser(proj[mask, :], os.path.join(out_dir, 'betti'))
    for dim, lens in enumerate(bar_lens_list):
        np.save(os.path.join(out_dir, f'h{dim}_lens.npy'), lens)
    return bar_lens_list


for acc in [60]:
    # Previously also run on seeds [2, 3, 5, 6, 7, 11, 13, 14, 16, 17, 18, 20].
    for seed in range(12, 21):
        model_folder = (
            f'{RESULTS_ROOT}/2_stim_batch_size_{bs}/n_hidden_{n_rnn}/'
            f'2_stim_batch_size_{bs}_n_hidden_{n_rnn}_acc_{acc}_seed_{seed}_with_noise/'
        )
        print('model folder: ' + model_folder)
        isomap_folder = os.path.join(model_folder, 'isomap')
        if not os.path.exists(isomap_folder):
            continue

        proj = np.load(os.path.join(isomap_folder, 'proj.npy'))
        # Each row of proj must line up with one flat index of the stimulus grid.
        assert len(proj) == stim_loc_size, (
            f'proj has {len(proj)} rows, expected {stim_loc_size} from the stimulus grid')
        betti_folder = os.path.join(isomap_folder, 'betti')

        # Use n_loc (not a literal 128) so the slices follow the grid definition.
        for num in range(n_loc):
            # Slice with stimulus 1 fixed at `num` (stimulus 2 varies), then the converse.
            # The last computed lengths stay bound at module level for inspection.
            h0_lens, h1_lens, h2_lens = _save_subspace_bar_lens(
                proj, ind_stim_loc1 == num, os.path.join(betti_folder, f'stim1_{num}'))
            h0_lens, h1_lens, h2_lens = _save_subspace_bar_lens(
                proj, ind_stim_loc2 == num, os.path.join(betti_folder, f'stim2_{num}'))

model folder: /Volumes/Seagate Backup Plus Drive/fyp/results/2_stim_batch_size_512/n_hidden_256/2_stim_batch_size_512_n_hidden_256_acc_60_seed_12_with_noise/
model folder: /Volumes/Seagate Backup Plus Drive/fyp/results/2_stim_batch_size_512/n_hidden_256/2_stim_batch_size_512_n_hidden_256_acc_60_seed_13_with_noise/
model folder: /Volumes/Seagate Backup Plus Drive/fyp/results/2_stim_batch_size_512/n_hidden_256/2_stim_batch_size_512_n_hidden_256_acc_60_seed_14_with_noise/
model folder: /Volumes/Seagate Backup Plus Drive/fyp/results/2_stim_batch_size_512/n_hidden_256/2_stim_batch_size_512_n_hidden_256_acc_60_seed_15_with_noise/
model folder: /Volumes/Seagate Backup Plus Drive/fyp/results/2_stim_batch_size_512/n_hidden_256/2_stim_batch_size_512_n_hidden_256_acc_60_seed_16_with_noise/
model folder: /Volumes/Seagate Backup Plus Drive/fyp/results/2_stim_batch_size_512/n_hidden_256/2_stim_batch_size_512_n_hidden_256_acc_60_seed_17_with_noise/
model folder: /Volumes/Seagate Backup Plus Drive/fyp

In [2]:
# Display the H1 bar lengths left over from the last run_ripser call in the batch
# loop, i.e. the final stim2 slice of the last seed processed.
# NOTE(review): h1_lens only exists if the batch cell ran in this kernel session.
# The recorded NameError (and the execution count going back to 2) suggests it was
# evaluated after a restart. Consider np.load-ing a specific .../h1_lens.npy instead.
h1_lens

NameError: name 'h1_lens' is not defined