<a href="https://colab.research.google.com/github/mtsizh/bottleneck-distance-for-sigma8/blob/main/wasserstein_distance.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#@title #Download datasets
#@markdown Run this block of code to download our dataset to Google Colab.
#@markdown If successful this code yields 
#@markdown * folder `alldata2` containing all data for further scripts
#@markdown * file `alldata_large.zip` -- the same data zipped; not needed for scripts, but you may want to download it to your local machine

from IPython.display import clear_output

# Names of the pieces of the split archive fetched from the GitHub repository.
files = ['alldata.zip', 'alldata.z01', 'alldata.z02', 'alldata.z03']
# Remove the results of any previous run so the cell can be executed repeatedly.
print('deleting old files if exist')
!rm -rf alldata2
!rm -f alldata_large.zip
for f in files:
  !rm -f {f}
# Download the parts one by one. Each shell line echoes "1" on success and "0"
# on failure; `result = !...` captures the echoed lines as a list, so success
# is detected as `result == ['1']`.
for i,f in enumerate(files):
  print('downloading part ', i+1, ' of ', len(files))
  # Earlier curl-based variant of the download, kept disabled:
  #result = !curl -s -S https://raw.githubusercontent.com/mtsizh/bottleneck-distance-for-sigma8/main/{f} > /dev/null && echo "TRUE"
  result = !wget -q https://raw.githubusercontent.com/mtsizh/bottleneck-distance-for-sigma8/main/{f} && echo "1" || echo "0"
  if result == ['1']:
    print('OK')
  else:
    raise Exception("ERROR WHILE DOWNLOADING DATA")
# Replace the per-part progress messages with a single status line.
clear_output()
print("DOWNLOAD SUCCESSFUL")
print("PARTS TO ZIP")
# `zip -F ... --out` is used to write the downloaded parts into one archive, alldata_large.zip.
result = !zip -qq -F alldata.zip --out alldata_large.zip 2>/dev/null && echo "1" || echo "0"
if result == ['1']:
  print('OK')
else:
  raise Exception("ERROR WHILE CREATING LARGE ZIP")
print("UNZIPPING")
# -o: overwrite existing files without prompting; -qq: no output.
result = !unzip -o -qq alldata_large.zip 2>/dev/null && echo "1" || echo "0"
if result == ['1']:
  print('OK')
else:
  raise Exception("ERROR WHILE UNZIPPING")
# The individual parts are no longer needed; alldata_large.zip is kept.
print('cleaning')
for f in files:
  !rm -f {f}
clear_output()
print("ALL DATA DOWNLOADED AND UNPACKED")

ALL DATA DOWNLOADED AND UNPACKED


Install the GUDHI library with the following code.

In [2]:
!pip install gudhi

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting gudhi
  Downloading gudhi-3.7.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (31.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m31.6/31.6 MB[0m [31m10.4 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: gudhi
Successfully installed gudhi-3.7.1


The following code calculates persistence intervals. Results are saved to the file `persistence.json`. The computation is lengthy; please wait until the progress bar reaches 100%. For your convenience, `persistence.json.zip` is also created so that you can download the intermediate data to your computer.

In [3]:
#@markdown Folder with data (should be ok by default)
root_folder = "./alldata2" #@param {type: "string"}
#@markdown Choose mass filtering in % (0 to 100)
remove_masses_lower_than = 0  #@param {type: "slider", min: 0, max: 100}
remove_masses_higher_than = 100  #@param {type: "slider", min: 0, max: 100}
#@markdown Choose random filtering in % (0 to 100)
bootstrap_percent = 4  #@param {type: "slider", min: 0, max: 100}
#@markdown File to save results, use JSON only! (should be ok by default)
save_to = "persistence.json" #@param {type: "string"}

# Reject an inverted or zero-width mass window before any lengthy computation starts.
# ValueError is a subclass of Exception, so existing handlers keep working.
if remove_masses_higher_than <= remove_masses_lower_than:
  raise ValueError("Upper masses limit should be larger than the lower limit")

import os
from glob import glob
import re
import numpy as np
import pandas as pd
import gudhi
from ipywidgets import IntProgress
from IPython import display


def read_all_data(filename, bootstrap_percent=100, mass_filter=(0, 100), rng=None):
  """Load a tab-separated catalogue and return a random subsample of its points.

  The first four columns of the file are read as x, y, z and mass m; any
  further columns are discarded.

  Parameters:
    filename: path of the tab-separated file (no header row).
    bootstrap_percent: percentage (0-100) of the mass-filtered rows to keep,
      drawn without replacement.
    mass_filter: pair (low, high) of percentiles; only rows whose mass lies
      between the two corresponding mass quantiles (inclusive) are kept.
    rng: optional `numpy.random.Generator` used for the draw. When omitted the
      global NumPy random state is used, exactly as before.

  Returns:
    A NumPy array with one row per sampled point and columns x, y, z.
  """
  df = pd.read_csv(filename, sep=r'\t', header=None, engine ='python')
  # Build the working frame without in-place edits: an in-place drop on a
  # filtered slice is what triggers pandas' SettingWithCopy warnings.
  points = (
      df.drop(columns=df.columns.difference([0, 1, 2, 3]))
        .rename(columns={0: 'x', 1: 'y', 2: 'z', 3: 'm'})
  )
  min_m = points['m'].quantile(mass_filter[0]/100)
  max_m = points['m'].quantile(mass_filter[1]/100)
  in_mass_window = points['m'].between(min_m, max_m)
  all_arr = points.loc[in_mass_window, ['x', 'y', 'z']].to_numpy()
  sample_size = all_arr.shape[0]*bootstrap_percent//100
  choose = np.random.choice if rng is None else rng.choice
  return all_arr[choose(all_arr.shape[0], replace=False, size=sample_size)]

def get_persistence_intervals(point_set, dim):
  """Return the persistence intervals of dimension `dim` for `point_set`.

  An alpha complex is built on the points, turned into a simplex tree (with
  real filtration values, i.e. `default_filtration_value=False`), and its
  persistence is computed before the intervals of the requested dimension
  are extracted.
  """
  tree = gudhi.AlphaComplex(points=point_set).create_simplex_tree(
      default_filtration_value=False)
  tree.compute_persistence()
  return tree.persistence_intervals_in_dimension(dim)

# Collect every *.dat file below root_folder.
dat_files = [y for x in os.walk(root_folder) for y in glob(os.path.join(x[0], '*.dat'))]
# File names look like ".../<sigma8*10>groups_<red shift>_new.dat".
# Raw string: the former literal relied on the invalid escape "\/", and its
# unescaped "." matched any character.
reg_expr = r'.*/([0-9]+)groups_([0-9]+)_new\.dat'
parsed = []
for f in dat_files:
  name_match = re.match(reg_expr, f)
  if name_match is None:
    # Previously this surfaced as an AttributeError on None.
    raise ValueError("Unexpected data file name: " + f)
  parsed.append({'filename': f,
                 'sigma': int(name_match.group(1)),
                 'red_shift': int(name_match.group(2))})

bar = IntProgress(min=0, max=len(parsed))
bar.value = 0
display.display(bar)

# Accumulate plain records and build the frame once at the end:
# DataFrame.append was removed in pandas 2.0 and was quadratic in a loop.
records = []
for f in parsed:
  data = read_all_data(f['filename'], bootstrap_percent,
                       [remove_masses_lower_than, remove_masses_higher_than])
  for dim in [0, 1, 2]:
    I = get_persistence_intervals(data, dim)
    records.append({'sigma8*10': f['sigma'],
                    'red_shift': f['red_shift'],
                    'dimension': dim,
                    'born': I[:,0].tolist(),
                    'dies': I[:,1].tolist(),
                    'persists': (I[:,1] - I[:,0]).tolist()})
  bar.value += 1

# Same column order the appended frame used to end up with.
all_data = pd.DataFrame(records, columns=["sigma8*10", "red_shift", "dimension",
                                          "born", "persists", "dies"])
all_data.to_json(save_to)
print("JSON creation complete!")

result =!zip -qq -r {save_to}.zip {save_to} 2>/dev/null && echo "1" || echo "0"
if result == ['1']:
  print('zipped')
else:
  raise Exception("ERROR WHILE CREATING LARGE ZIP")

IntProgress(value=0, max=144)

JSON creation complete!
zipped


The following code visualizes results of the previous calculation.

In [4]:
#@markdown File to load data (should be ok by default)
read_from = "persistence.json" #@param {type: "string"}


import itertools

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
sns.set_style("whitegrid")

df = pd.read_json(read_from)
# Merge all rows that share (sigma8*10, red shift, dimension) by concatenating
# their list-valued columns. chain.from_iterable is linear in the total
# length, unlike sum(lists, []), which re-copies the accumulator at each step.
df = (df.groupby(['sigma8*10', 'red_shift', 'dimension'])
        .agg(lambda lists: list(itertools.chain.from_iterable(lists)))
        .reset_index())


def plot_by_params(dim, sig):
  """Scatter birth against persistence for one dimension and one sigma8.

  Parameters:
    dim: value of the 'dimension' column to show.
    sig: value of the 'sigma8*10' column to show (displayed divided by 10).

  One point cloud is drawn per red shift found in the selection, each with
  its own colour from the 'plasma' colormap. Reads the module-level `df`.
  """
  plt.figure(figsize=(12,7))
  sns.set_context("notebook", font_scale=1.5)

  # Raw string keeps "\sigma" intact without relying on an invalid escape.
  plt.title(r"$\sigma_8 = "+str(sig/10)+" $, dimension: "+str(dim))
  plt.xlabel("Birth")
  plt.ylabel("Persistence")

  pick_df = df[(df['sigma8*10'] == sig) & (df['dimension'] == dim)]
  r_shifts = pick_df['red_shift'].unique()
  # plt.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9.
  cmap = plt.get_cmap('plasma')

  for idx,red_shift in enumerate(r_shifts):
    # First row for this red shift; filtered once instead of once per column.
    row = pick_df[pick_df['red_shift'] == red_shift].iloc[0]
    X = np.array(row['born'])
    Y = np.array(row['persists'])
    plt.scatter(X, Y, color=cmap(idx/len(r_shifts)), label="red shift: "+str(red_shift))

  plt.legend()
  plt.show()
  

import ipywidgets as widgets

def make_widget(field):
  """Return an enabled dropdown listing the distinct values of `df[field]`, labelled `field`."""
  choices = df[field].unique()
  dropdown = widgets.Dropdown(options=choices, description=field, disabled=False)
  return dropdown

from ipywidgets import interactive_output

# One dropdown per selectable column; interactive_output feeds their current
# values to plot_by_params as `dim` and `sig`.
w = [make_widget('dimension'), make_widget('sigma8*10')]
ui = widgets.HBox(w)
out = widgets.interactive_output(plot_by_params, dict(dim=w[0], sig=w[1]))
display.display(ui, out)


HBox(children=(Dropdown(description='dimension', options=(0, 1, 2), value=0), Dropdown(description='sigma8*10'…

Output()

In [15]:
# Order of the Wasserstein distance passed to hera.
order = 1
read_from = "persistence.json"
save_to = "distances.json"

import numpy as np
import pandas as pd
import gudhi
import gudhi.hera
# Imported here too, so this cell does not depend on the persistence cell
# having been run in the same kernel session.
from ipywidgets import IntProgress
from IPython import display

df = pd.read_json(read_from)

# The bar advances once per persistence row (once per diagram used as the
# first element of a pair).
bar = IntProgress(min=0, max=len(df.index))
bar.value = 0
display.display(bar)

# Collect plain records and build the frame once: DataFrame.append was
# removed in pandas 2.0 and is quadratic when used in a loop.
rows = []
for r_shift in df['red_shift'].unique():
  df1 = df[df['red_shift'] == r_shift]
  for dimension in df['dimension'].unique():
    df2 = df1[df1['dimension'] == dimension]
    sigmas = df2['sigma8*10'].tolist()
    # Build each (born, dies) diagram once per group rather than once per pair.
    diagrams = [np.transpose([born, dies]).astype(float)
                for born, dies in zip(df2['born'], df2['dies'])]
    # Maps |sigma8 difference| * 10 to the distances of all pairs with that difference.
    distances = {}
    for idx1 in range(len(diagrams)):
      # Only unordered pairs idx1 < idx2 are compared.
      for idx2 in range(idx1 + 1, len(diagrams)):
        # Alternative metric: gudhi.bottleneck_distance(diagrams[idx1], diagrams[idx2])
        dist = float(gudhi.hera.wasserstein_distance(diagrams[idx1], diagrams[idx2], order))
        dsigma = abs(sigmas[idx1] - sigmas[idx2])
        distances.setdefault(dsigma, []).append(dist)
      bar.value += 1
    for dsigma, dists in distances.items():
      rows.append({'Dsigma8*10': dsigma,
                   'red_shift': r_shift,
                   'dimension': dimension,
                   'distances': dists})

result = pd.DataFrame(rows, columns=["Dsigma8*10", "red_shift", "dimension", "distances"])
result.to_json(save_to)
print("JSON creation complete!")

IntProgress(value=0, max=432)

JSON creation complete!


In [17]:
# File holding the distance table to inspect.
read_from = "distances.json"

import numpy as np
import pandas as pd
import gudhi
import gudhi.hera

# Load the table; it has the columns Dsigma8*10, red_shift, dimension and distances.
df = pd.read_json(read_from)
# Bare last expression: the notebook renders the frame as a table.
df

Unnamed: 0,Dsigma8*10,red_shift,dimension,distances
0,3,20,0,"[0.0743839904, 0.0760524798, 0.1249733163, 0.1..."
1,1,20,0,"[0.08318562360000001, 0.052760789200000005, 0...."
2,4,20,0,"[0.0777298632, 0.094852747, 0.0793013442, 0.07..."
3,2,20,0,"[0.083525207, 0.063229327, 0.06035517100000000..."
4,5,20,0,"[0.0744399746, 0.06794893070000001, 0.07571615..."
5,0,20,0,"[0.0527908559, 0.07764946040000001, 0.09399612..."
6,3,20,1,"[0.16999466, 0.1794660746, 0.1533517674, 0.161..."
7,1,20,1,"[0.1480353797, 0.2033783999, 0.1633751071, 0.1..."
8,4,20,1,"[0.14736843260000002, 0.1864187418, 0.18428067..."
9,2,20,1,"[0.16589547300000002, 0.1587500746, 0.17902874..."


In [136]:
read_from = "distances.json"

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
sns.set_style("whitegrid")
from matplotlib.patches import Patch
from matplotlib.lines import Line2D

df = pd.read_json(read_from)

# Legend handles: one filled circle per homology dimension
# (0 -> red, 1 -> blue, 2 -> green), drawn on a white line.
legend_elements = [
  Line2D([0], [0], marker='o', color='w',
         label='$' + str(homology_dim) + '$ homologies',
         markerfacecolor=face_colour, markersize=10)
  for homology_dim, face_colour in enumerate(['r', 'b', 'g'])
]

def plot_data(ax, data, box_color='b', shift=0.0):
  """Draw one box per sigma8 difference, plus a line through the medians.

  Parameters:
    ax: axes-like object to draw on.
    data: frame with a 'Dsigma8*10' column and a 'distances' column holding
      one list of distances per row.
    box_color: colour used for boxes, whiskers, caps, fliers and the median line.
    shift: horizontal offset added to every box position, so that several
      series can be drawn side by side on the same axes.

  Boxes are placed at Dsigma8*10 / 10 (+ shift), sorted ascending; the x ticks
  are put at the unshifted positions.
  """
  import warnings

  # Raw string keeps the LaTeX backslashes without relying on invalid escapes.
  ax.set_xlabel(r"$\Delta \sigma_8$")
  ax.set_ylabel("$W$ distance")
  y = data['distances'].tolist()
  x = data['Dsigma8*10'].tolist()
  # Reorder the distance lists so they follow ascending Dsigma8.
  y = [y[i] for i in np.argsort(x)]
  xu = np.sort(x)/10
  medians = [np.median(yd) for yd in y]
  # The warning class lives in numpy.exceptions on recent NumPy and was removed
  # from the top-level namespace in NumPy 2.0; fall back gracefully on either.
  ragged_warning = getattr(getattr(np, 'exceptions', np),
                           'VisibleDeprecationWarning', UserWarning)
  # Silence it only around the boxplot call instead of for the whole session.
  with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=ragged_warning)
    ax.boxplot(y,
                showmeans=True, positions=xu + shift, widths=0.006,
                boxprops=dict(color=box_color),
                whiskerprops=dict(color=box_color),
                capprops=dict(color=box_color),
                flierprops=dict(color=box_color, markeredgecolor=box_color))
  ax.plot(xu + shift, medians, c=box_color)
  ax.set_xlim(np.min(xu)-0.025, np.max(xu)+0.025)
  ax.set_navigate(False)
  ax.set_xticks(xu)
  ax.set_xticklabels(xu)


def plot_by_params(red_s, Dsigma8):
  """Plot the distance box plots of every homology dimension for one red shift.

  Parameters:
    red_s: value of the 'red_shift' column to show.
    Dsigma8: value chosen in the second dropdown. It is accepted so the
      widget wiring keeps working, but it is not used by the plot.

  A 1x2 figure is created; all series are drawn on the left axes and the
  right axes are left empty. Reads the module-level `df` and `legend_elements`.
  """
  # Per-dimension colour and small horizontal offset, so boxes do not overlap.
  colours = {0:'r', 1:'b', 2:'g'}
  offsets = {0:-0.005, 1:0.0, 2:+0.005}

  sns.set_context("notebook", font_scale=1.5)
  fig, axs = plt.subplots(1, 2, figsize=(20,6))
  df_filtered = df[df['red_shift'] == red_s]
  for d in np.unique(df_filtered['dimension']):
    data = df_filtered[df_filtered['dimension'] == d]
    plot_data(axs[0], data, box_color=colours[d], shift=offsets[d])
  axs[0].legend(handles=legend_elements, loc='upper left', prop={'size': 12})
  plt.show()
  
import ipywidgets as widgets

def make_widget(options, descr):
  """Return an enabled dropdown offering `options`, labelled with `descr`."""
  dropdown = widgets.Dropdown(options=options, description=descr, disabled=False)
  return dropdown

from ipywidgets import interactive_output

# Imported here so the cell does not depend on an earlier cell's import.
from IPython import display

# The second dropdown lists Dsigma8*10 divided by 10, i.e. plain Delta sigma8
# values, so its label must not say "x 10" (it also carried a stray "$").
w = [make_widget(np.unique(df['red_shift']), 'red shift'),
     make_widget(np.unique(df['Dsigma8*10'])/10, 'Delta sigma8')]
ui = widgets.HBox(w)
out = widgets.interactive_output(plot_by_params, {'red_s': w[0],
                                                  'Dsigma8': w[1]})
display.display(ui, out)



HBox(children=(Dropdown(description='red shift', options=(13, 15, 20), value=13), Dropdown(description='Delta …

Output()

In [115]:
# Scratch inspection: rows for red shift 13, then only homology dimension 0.
is_red_shift_13 = df['red_shift'] == 13
df_filtered = df[is_red_shift_13]
is_dimension_0 = df_filtered['dimension'] == 0
data = df_filtered[is_dimension_0]
data

Unnamed: 0,Dsigma8*10,red_shift,dimension,distances
36,3,13,0,"[0.059265485900000005, 0.08853881570000001, 0...."
37,1,13,0,"[0.19487002250000002, 0.0700731789, 0.21193032..."
38,4,13,0,"[0.10262976260000001, 0.1076690648, 0.11476429..."
39,2,13,0,"[0.0874694067, 0.07132520440000001, 0.09240037..."
40,0,13,0,"[0.1238520134, 0.1223192372, 0.0676546709, 0.0..."
41,5,13,0,"[0.2694663071, 0.269346401, 0.2642412292, 0.21..."


In [119]:
# Scratch check: the Dsigma8*10 values of the selection in ascending order.
dsigma_values = data['Dsigma8*10'].tolist()
idxs = np.argsort(dsigma_values)
np.array(dsigma_values)[idxs]

array([0, 1, 2, 3, 4, 5])

In [116]:
# Scratch check: show the 'distances' column of the current selection `data`.
data['distances']

36    [0.059265485900000005, 0.08853881570000001, 0....
37    [0.19487002250000002, 0.0700731789, 0.21193032...
38    [0.10262976260000001, 0.1076690648, 0.11476429...
39    [0.0874694067, 0.07132520440000001, 0.09240037...
40    [0.1238520134, 0.1223192372, 0.0676546709, 0.0...
41    [0.2694663071, 0.269346401, 0.2642412292, 0.21...
Name: distances, dtype: object