<a href="https://colab.research.google.com/github/fbeilstein/algorithms/blob/master/wasserstein_distance.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title #Download datasets
#@markdown Run this block of code to download our dataset to Google Colab.
#@markdown If successful this code yields 
#@markdown * folder `alldata2` containing all data for further scripts
#@markdown * file `alldata_large.zip` -- the same data zipped; not needed for scripts, but you may want to download it to your local machine

from IPython.display import clear_output

files = ['alldata.zip', 'alldata.z01', 'alldata.z02', 'alldata.z03']
print('deleting old files if exist')
!rm -rf alldata2
!rm -f alldata_large.zip
for f in files:
  !rm -f {f}
for i,f in enumerate(files):
  print('downloading part ', i+1, ' of ', len(files))
  #result = !curl -s -S https://raw.githubusercontent.com/mtsizh/bottleneck-distance-for-sigma8/main/{f} > /dev/null && echo "TRUE"
  result = !wget -q https://raw.githubusercontent.com/mtsizh/bottleneck-distance-for-sigma8/main/{f} && echo "1" || echo "0"
  if result == ['1']:
    print('OK')
  else:
    raise Exception("ERROR WHILE DOWNLOADING DATA")
clear_output()
print("DOWNLOAD SUCCESSFUL")
print("PARTS TO ZIP")
result = !zip -qq -F alldata.zip --out alldata_large.zip 2>/dev/null && echo "1" || echo "0"
if result == ['1']:
  print('OK')
else:
  raise Exception("ERROR WHILE CREATING LARGE ZIP")
print("UNZIPPING")
result = !unzip -o -qq alldata_large.zip 2>/dev/null && echo "1" || echo "0"
if result == ['1']:
  print('OK')
else:
  raise Exception("ERROR WHILE UNZIPPING")
print('cleaning')
for f in files:
  !rm -f {f}
clear_output()
print("ALL DATA DOWNLOADED AND UNPACKED")

Install the `gudhi` library with the following code:

In [None]:
!pip install gudhi

The following code calculates persistence intervals. The results are saved to the `persistence.json` file. The computations are lengthy; please wait until the progress bar reaches 100%. For your convenience, `persistence.json.zip` is also created so that you can download the intermediate data to your computer.

In [3]:
#@markdown Folder with data (should be ok by default)
root_folder = "./alldata2" #@param {type: "string"}
#@markdown Choose mass filtering in % (0 to 100)
remove_masses_lower_than = 0  #@param {type: "slider", min: 0, max: 100}
remove_masses_higher_than = 100  #@param {type: "slider", min: 0, max: 100}
#@markdown Choose random filtering in % (0 to 100)
bootstrap_percent = 100  #@param {type: "slider", min: 0, max: 100}
#@markdown File to save results, use JSON only! (should be ok by default)
save_to = "persistence.json" #@param {type: "string"}

# The two mass sliders are percentiles; an empty or inverted range is rejected
# up front, before any lengthy computation starts.
if remove_masses_higher_than <= remove_masses_lower_than:
  raise ValueError("Upper masses limit should be larger than the lower limit")

import os
from glob import glob
import re
import numpy as np
import pandas as pd
import gudhi
from ipywidgets import IntProgress
from IPython import display


def read_all_data(filename, bootstrap_percent=100, mass_filter=(0, 100)):
  """Load point coordinates from a tab-separated file, filtered by mass.

  The first four columns of the file are read as x, y, z and mass m; any
  further columns are discarded.

  Args:
    filename: path of the tab-separated file (no header line).
    bootstrap_percent: percentage (0 to 100) of the mass-filtered rows to
      return. Rows are drawn at random WITHOUT replacement, so this is a
      random subsample; 100 returns all rows in shuffled order.
    mass_filter: pair (low, high) of percentiles (0 to 100). Only rows whose
      mass lies between the two corresponding quantiles (inclusive) are kept.

  Returns:
    numpy array with one row per selected point and the columns x, y, z.
  """
  raw = pd.read_csv(filename, sep=r'\t', header=None, engine='python')
  # Keep only the first four columns and give them readable names.
  points = (raw.drop(columns=raw.columns.difference([0, 1, 2, 3]))
               .rename(columns={0: 'x', 1: 'y', 2: 'z', 3: 'm'}))
  min_m = points['m'].quantile(mass_filter[0] / 100)
  max_m = points['m'].quantile(mass_filter[1] / 100)
  # Select rows and coordinate columns in one step instead of dropping a
  # column in place on a filtered slice (avoids SettingWithCopyWarning).
  all_arr = points.loc[points['m'].between(min_m, max_m), ['x', 'y', 'z']].to_numpy()
  n_rows = all_arr.shape[0]
  picked = np.random.choice(n_rows, replace=False,
                            size=n_rows * bootstrap_percent // 100)
  return all_arr[picked]

def get_persistence_intervals(point_set, dim):
  """Return the persistence intervals of `point_set` in homology dimension `dim`.

  An alpha complex is built from the points with gudhi, its simplex tree is
  created, persistence is computed on that tree and the intervals of the
  requested dimension are returned as delivered by gudhi.

  Note: the whole complex is rebuilt and persistence recomputed on every call.
  """
  complex_of_points = gudhi.AlphaComplex(points=point_set)
  tree = complex_of_points.create_simplex_tree(default_filtration_value=False)
  # Persistence has to be computed before the intervals can be queried.
  tree.compute_persistence()
  return tree.persistence_intervals_in_dimension(dim)

dat_files = [y for x in os.walk(root_folder) for y in glob(os.path.join(x[0], '*.dat'))]
reg_expr = '.*\/([0-9]+)groups_([0-9]+)_new.dat'
parsed = [{'filename': f, 
           'sigma': int(re.match(reg_expr, f).group(1)), 
           'red_shift': int(re.match(reg_expr, f).group(2))
           } for f in dat_files]
all_data = pd.DataFrame(columns=["sigma8*10", "red_shift", "dimension", "born", "persists"])

bar = IntProgress(min=0, max=len(parsed))
bar.value = 0
display.display(bar)

for f in parsed:
  data = read_all_data(f['filename'], bootstrap_percent, 
                       [remove_masses_lower_than, remove_masses_higher_than])
  for dim in [0, 1, 2]:
    I = get_persistence_intervals(data, dim)
    df = {'sigma8*10': f['sigma'], 
          'red_shift': f['red_shift'], 
          'dimension': dim,
          'born': I[:,0].tolist(),
          'dies': I[:,1].tolist(),
          'persists': (I[:,1] - I[:,0]).tolist()}
    all_data = all_data.append(df, ignore_index = True)
  bar.value += 1 

data = all_data.groupby(['sigma8*10', 'red_shift', 'dimension']).agg(lambda x: sum(list(x), [])).reset_index()
data.to_json(save_to)
print("JSON creation complete!")

result =!zip -qq -r {save_to}.zip {save_to} 2>/dev/null && echo "1" || echo "0"
if result == ['1']:
  print('zipped')
else:
  raise Exception("ERROR WHILE CREATING LARGE ZIP")

IntProgress(value=0, max=144)

JSON creation complete!
zipped


The following code visualizes the results of the previous calculation.

In [28]:
#@markdown File to load data (should be ok by default)
read_from = "persistence.json" #@param {type: "string"}


import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
# Global seaborn style used by all figures drawn below.
sns.set_style("whitegrid")

# Persistence data as a DataFrame; the plotting and widget helpers below read
# this module-level `df` directly.
df = pd.read_json(read_from)


def plot_by_params(dim, sig):
  """Scatter-plot persistence against birth for one dimension and one sigma8.

  Reads the module-level DataFrame `df` and draws one point cloud per red
  shift found for the requested combination, each in its own colour.

  Args:
    dim: value of the 'dimension' column to select.
    sig: value of the 'sigma8*10' column to select; shown in the title
      divided by 10.
  """
  plt.figure(figsize=(12,7))
  sns.set_context("notebook", font_scale=1.5)

  # Raw string: "\s" is not a valid escape sequence in a normal string literal.
  plt.title(r"$\sigma_8 = "+str(sig/10)+" $, dimension: "+str(dim))
  plt.xlabel("Birth")
  plt.ylabel("Persistence")

  pick_df = df[(df['sigma8*10'] == sig) & (df['dimension'] == dim)]
  r_shifts = pick_df['red_shift'].unique()
  # plt.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
  # pyplot.get_cmap is available in old and new releases alike.
  cmap = plt.get_cmap('plasma')

  for idx,red_shift in enumerate(r_shifts):
    # Filter once per red shift and take the first matching row.
    selected = pick_df[pick_df['red_shift'] == red_shift]
    # dtype=float turns missing values (JSON null) into NaN.
    X = np.array(selected['born'].values[0], dtype=float)
    Y = np.array(selected['persists'].values[0], dtype=float)
    plt.scatter(X, Y, color=cmap(idx/len(r_shifts)), label="red shift: "+str(red_shift))

  plt.legend()
  plt.show()
  

import ipywidgets as widgets

def make_widget(field):
  """Build an enabled dropdown offering the distinct values of `df[field]`.

  `field` is also used as the label shown next to the dropdown.
  """
  choices = df[field].unique()
  dropdown = widgets.Dropdown(options=choices, description=field, disabled=False)
  return dropdown

# Imported here so that this cell does not depend on an earlier cell having
# bound the name `display` to the IPython.display module.
from IPython import display
from ipywidgets import interactive_output

# One dropdown per selectable parameter, laid out side by side.
w = [make_widget(field) for field in ['dimension', 'sigma8*10']]
ui = widgets.HBox(w)
# Redraw the plot whenever one of the dropdowns changes.
out = widgets.interactive_output(plot_by_params, {'dim': w[0],
                                                  'sig': w[1]})
display.display(ui, out)

HBox(children=(Dropdown(description='dimension', options=(0, 1, 2), value=0), Dropdown(description='sigma8*10'…

Output()