### Persistent homology examples

* Ripser [paper](https://www.theoj.org/joss-papers/joss.00925/10.21105.joss.00925.pdf) [code](https://github.com/scikit-tda/ripser.py) (fast)
* Dionysus 2 [code](https://mrzv.org/software/dionysus2/) (representative examples)
* Nico's [code](https://github.com/nhchristianson/Math-text-semantic-networks)
* Ann's [code](https://github.com/asizemore/PH_tutorial/blob/master/Tutorial_day1.ipynb)

## Load networks

In [None]:
# Enable IPython's autoreload extension in mode 2 so edited modules are
# re-imported before each cell runs (no kernel restart needed while developing).
%reload_ext autoreload
%autoreload 2
import os,sys
# Put the sibling '../module' directory (relative to sys.path[0]) on the import
# path; presumably this is where the project's `wiki` module lives -- confirm.
sys.path.insert(1, os.path.join(sys.path[0], '..', 'module'))

In [None]:
# Topic names analysed throughout the notebook; each name is also the stem of
# the graph/barcode files loaded below, so spelling and case must match the files.
topics = [
    'anatomy', 'biochemistry', 'cognitive science', 'evolutionary biology',
    'genetics', 'immunology', 'molecular biology',
    'chemistry', 'biophysics', 'energy', 'optics',
    'earth science', 'geology', 'meteorology',
    'philosophy of language', 'philosophy of law',
    'philosophy of mind', 'philosophy of science',
    'economics', 'accounting', 'education', 'linguistics', 'law',
    'psychology', 'sociology',
    'electronics', 'software engineering', 'robotics',
    'calculus', 'geometry', 'abstract algebra', 'Boolean algebra',
    'commutative algebra', 'group theory', 'linear algebra', 'number theory',
    'dynamical systems and differential equations',
]

In [None]:
import wiki

# Directory containing '<topic>.pickle' (graph) and '<topic>.barcode' files.
path_saved = '/Users/harangju/Developer/data/wiki/graphs/dated/'

# Build one wiki.Net per topic, echoing each topic name as a progress indicator.
networks = {}
for topic in topics:
    print(topic, end=' ')
    stem = path_saved + topic
    networks[topic] = wiki.Net(path_graph=stem + '.pickle',
                               path_barcodes=stem + '.barcode')

In [None]:
# Null models stored under 'null-target/' (titled 'target-rewired' in the plots
# below): `num_nulls` replicates per topic, in files '<topic>-null-<i>.*'.
path_null = '/Users/harangju/Developer/data/wiki/graphs/null-target/'
num_nulls = 2
null_targets = {}
for topic in topics:
    replicates = []
    for i in range(num_nulls):
        stem = path_null + topic + '-null-' + str(i)
        replicates.append(wiki.Net(path_graph=stem + '.pickle',
                                   path_barcodes=stem + '.barcode'))
    null_targets[topic] = replicates

In [None]:
# Null models stored under 'null-year/' (titled 'year-reordered' in the plots
# below): `num_nulls` replicates per topic, in files '<topic>-null-<i>.*'.
path_null = '/Users/harangju/Developer/data/wiki/graphs/null-year/'
num_nulls = 2
null_years = {}
for topic in topics:
    replicates = []
    for i in range(num_nulls):
        stem = path_null + topic + '-null-' + str(i)
        replicates.append(wiki.Net(path_graph=stem + '.pickle',
                                   path_barcodes=stem + '.barcode'))
    null_years[topic] = replicates

In [None]:
import pandas as pd

# Limit rendered tables to 12 rows so large barcode frames stay readable.
pd.set_option('display.max_rows', 12)
# Peek at the barcode table of the first 'null-target' replicate for one topic.
null_targets['robotics'][0].barcodes

In [None]:
# Stack every barcode table into one long frame, tagging each row with its
# topic, its network type and its null replicate index (-1 for real networks).
# Only the 'null_targets' replicates are included; `null_years` is not.
real_frames = [network.barcodes.assign(topic=topic, type='real', null=-1)
               for topic, network in networks.items()]
null_frames = [network.barcodes.assign(topic=topic, type='null_targets', null=i)
               for topic, nulls in null_targets.items()
               for i, network in enumerate(nulls)]
barcodes = pd.concat(real_frames + null_frames, ignore_index=True, sort=False)
barcodes

## Plotting functions

In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib._color_data as mcd
plt.rcParams.update({'figure.max_open_warning': 0})

sns.set(style='white', font_scale=1.4)
def plot_barcodes(barcodes, inf_end=2050, cutoff=2040):
    """Draw a barcode plot: one horizontal bar per row of `barcodes`.

    A new 18x6 figure is created. Each row is drawn as a line from its
    'birth' to its 'death' at a height equal to the row's index label and
    coloured by its 'dim'. Rows whose death is infinite are extended to
    `inf_end`; finite deaths are marked with a red 'x'. A dashed grey
    vertical line is drawn at `cutoff`, and y tick labels are hidden.

    Parameters
    ----------
    barcodes : pandas.DataFrame
        Must provide 'birth', 'death' and 'dim' columns. 'dim' is used as an
        index into an 8-colour palette, so it must be an integer value < 8.
    inf_end : number, optional
        x position at which bars with infinite death are cut off
        (default 2050, the value previously hard-coded).
    cutoff : number, optional
        x position of the dashed vertical marker line (default 2040).
    """
    colors = [mcd.XKCD_COLORS['xkcd:'+c]
              for c in ['emerald green', 'tealish', 'peacock blue',
                        'grey', 'brown', 'red', 'yellow', 'green']]
    plt.figure(figsize=(18,6))
    for i, row in barcodes.iterrows():
        birth = row['birth']
        death = row['death']
        is_finite = death != np.inf
        x = [birth, death] if is_finite else [birth, inf_end]
        # iterrows() does not preserve dtypes, so 'dim' can arrive as a float;
        # cast before indexing. The colour goes in by keyword so it is never
        # parsed as a matplotlib format string.
        plt.plot(x, [i, i], color=colors[int(row['dim'])])
        if is_finite:
            plt.plot(death, i, 'rx')
    plt.axvline(x=cutoff, linestyle='--', color=mcd.XKCD_COLORS['xkcd:grey'])
    plt.gca().axes.yaxis.set_ticklabels([])

In [None]:
def plot_persistence_diagram(barcodes, inf_death=2030):
    """Scatter birth against death for every barcode, one colour per dimension.

    A new 10x10 figure is created. Rows with infinite death are plotted at
    `inf_death`, and a dashed horizontal line at that height spans from the
    smallest birth to the largest finite death. The input frame is not
    modified.

    Parameters
    ----------
    barcodes : pandas.DataFrame
        Must provide 'birth', 'death' and 'dim' columns.
    inf_death : number, optional
        y value standing in for an infinite death (default 2030, the value
        previously hard-coded).
    """
    colors = [mcd.XKCD_COLORS['xkcd:'+c]
              for c in ['emerald green', 'tealish', 'peacock blue',
                        'grey', 'brown', 'red', 'yellow']]
    plt.figure(figsize=(10,10))
    # Sorted so the drawing order does not depend on set iteration order.
    for dim in sorted(set(barcodes['dim'])):
        # Work on an explicit copy: assigning into a filtered slice triggers
        # pandas' SettingWithCopyWarning.
        data = barcodes.loc[barcodes['dim']==dim].copy()
        data.loc[data['death']==np.inf,'death'] = inf_death
        # Use the palette (previously built but unused); wrap around if there
        # are more dimensions than colours.
        plt.plot(data['birth'], data['death'], '.',
                 color=colors[int(dim) % len(colors)])
    x = [barcodes['birth'].min(),
         barcodes.loc[barcodes['death']!=np.inf,'death'].max()]
    plt.plot(x, [inf_death, inf_death], '--')

## Plot barcodes

In [None]:
# List nodes of the chosen topic whose 'year' attribute is later than 2100.
# Other topics inspected here: 'geology', 'meteorology', 'electronics', 'robotics'.
topic = 'molecular biology'
node_view = networks[topic].graph.nodes
[n for n in node_view if node_view[n]['year']>2100]

In [None]:
# For each topic draw three barcode figures (real network, first 'null-target'
# replicate, first 'null-year' replicate), skipping zero-lifetime bars.
for topic in topics:
    print('Topic: ' + topic)
    panels = [(networks[topic], topic),
              (null_targets[topic][0], 'target-rewired'),
              (null_years[topic][0], 'year-reordered')]
    for net, title in panels:
        bars = net.barcodes
        plot_barcodes(bars[bars.lifetime!=0])
        plt.title(title)
    plt.show()

## Compare lifetimes (real vs null)

In [None]:
from scipy import stats

In [None]:
# Per topic: independent two-sample t-test on finite lifetimes, real network
# versus all 'null_targets' rows (every null replicate pooled together).
for topic in topics:
    finite = barcodes[(barcodes.topic==topic) & (barcodes.lifetime!=np.inf)]
    real_lifetimes = finite.loc[finite.type=='real', 'lifetime'].values
    null_lifetimes = finite.loc[finite.type=='null_targets', 'lifetime'].values
    t, p = stats.ttest_ind(real_lifetimes, null_lifetimes)
    print(topic, '\n\t', 't =', t, '\tp ={:6.5f}'.format(p))

In [None]:
# Split violins of finite lifetimes per topic, real versus null.
plt.figure(figsize=(20,6))
finite_barcodes = barcodes[barcodes.lifetime!=np.inf]
ax = sns.violinplot(data=finite_barcodes, x='topic', y='lifetime',
                    hue='type', split=True)
plt.xticks(np.arange(len(topics)), topics, rotation='vertical');

In [None]:
# Same violin plot, but infinite lifetimes are replaced by the largest finite
# lifetime observed for the same topic instead of being dropped.
plt.figure(figsize=(20,6))
finite_max = barcodes[barcodes.lifetime!=np.inf].groupby('topic')['lifetime'].max()
data = barcodes.copy().merge(finite_max, on='topic', suffixes=['','_max'])
is_inf = data.lifetime==np.inf
data.loc[is_inf,'lifetime'] = data.loc[is_inf,'lifetime_max']
sns.violinplot(data=data, x='topic', y='lifetime', hue='type', split=True)
plt.xticks(np.arange(len(topics)), topics, rotation='vertical');

In [None]:
# Per topic: overlay the lifetime distributions (death - birth) of the real
# network and of the first 'null-target' replicate. Infinite lifetimes are
# replaced by the largest finite lifetime of the same network.
for topic in topics:
    plt.figure(figsize=(20,4))
    for label, net in (('real', networks[topic]),
                       ('null-target', null_targets[topic][0])):
        lifetimes = net.barcodes.death.values - net.barcodes.birth.values
        # Compute the cap once per network rather than once per element, and
        # only when an infinite lifetime actually needs replacing.
        has_inf = (lifetimes==np.inf).any()
        cap = max(lifetimes[lifetimes!=np.inf]) if has_inf else None
        sns.distplot([x if x!=np.inf else cap for x in lifetimes],
                     hist=True, rug=True, label=label)
    plt.title(topic)
    plt.legend()

## Compare dimensions (real vs null)

In [None]:
# Number of barcodes per (type, topic, dim), joined back onto every row and
# shown as split violins per dimension.
plt.figure(figsize=(20,6))
dim_counts = barcodes.assign(count=1)\
                     .groupby(['type','topic','dim'])['count'].sum()
with_counts = barcodes.merge(dim_counts, on=['type','topic','dim'])
sns.violinplot(data=with_counts, x='dim', y='count', hue='type', split=True)

In [None]:
# Scatter of barcode counts per (topic, dim): null_targets on x, real on y,
# with the identity line drawn for reference.
plt.figure(figsize=(6,6))
dim_counts = barcodes.assign(count=1)\
                     .groupby(['type','topic','dim'])['count'].sum()
with_counts = barcodes.merge(dim_counts, on=['type','topic','dim'])
counts_by_type = with_counts.groupby(['topic','type','dim'], sort=False)['count'].mean()\
                            .unstack(level=1)\
                            .reset_index()
ax = sns.scatterplot(data=counts_by_type, x='null_targets', y='real',
                     hue='dim', palette='Pastel2')
sns.lineplot(x=[0,8000], y=[0,8000], ax=ax, label='equal')
plt.legend(loc='upper center')
plt.ylim([-1000,8000])
plt.xlim([-1000,8000]);

## Node importance

### Node participation in birth & deaths

### Identify important nodes

## Lifetime vs Cavity volume

Useful resources:
* [Computational topology](https://books.google.com/books?id=MDXa6gFRZuIC&printsec=frontcover#v=onepage&q=%22persistent%20homology%22&f=true)
* [tutorial](http://pages.cs.wisc.edu/~jerryzhu/pub/cvrghomology.pdf)

In [None]:
# Pick one topic's real network for ad-hoc inspection; note that the loops
# further down reassign both `topic` and `network`.
topic = 'biochemistry'
network = networks[topic]

In [None]:
import pickle
import numpy as np
import gensim.utils as gu
import gensim.matutils as gmat
import sklearn.metrics.pairwise as smp

In [None]:
# For every barcode, take the tf-idf columns of its 'homology nodes', compute
# their centroid, and store the mean cosine distance from each node to that
# centroid in barcodes['average distance']. Rows without nodes get 0.
n_rows = len(barcodes.index)
for i, row in barcodes.iterrows():
    sys.stdout.write("\rindex: " + str(i+1) + '/' + str(n_rows))
    sys.stdout.flush()
    nodes = row['homology nodes']
    topic = row['topic']
    # Look nodes up in the network the barcode came from: `barcodes` holds
    # rows from every null replicate, identified by the 'null' column, so
    # always using replicate 0 pairs some rows with the wrong network.
    network = networks[topic] if row['type']=='real'\
              else null_targets[topic][int(row['null'])]
    tfidf = network.graph.graph['tfidf']
    indices = [network.nodes.index(n) for n in nodes]
    if indices:
        vectors = tfidf[:,indices]
        centroid = vectors.mean(axis=1)
        distances = smp.cosine_distances(X=vectors.transpose(),
                                         Y=centroid.transpose())
    else:
        distances = [0]
    barcodes.loc[i,'average distance'] = np.mean(distances)

In [None]:
# Display the barcode table, which now carries the 'average distance' column
# filled in by the loop above.
barcodes

In [None]:
# Cache the annotated barcode table. The context manager guarantees the file
# is flushed and closed even if pickling fails part-way through.
with open('barcodes.pickle','wb') as f:
    pickle.dump(barcodes, f)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# Per topic (real networks only, finite non-zero lifetimes): regress lifetime
# on the average node-to-centroid distance and report Pearson's r and p.
for topic in networks:
    plt.figure(figsize=(6,6))
    data = barcodes[(barcodes.topic==topic) & (barcodes.type=='real') &
                    (barcodes.lifetime!=np.inf) & (barcodes.lifetime!=0)]
    # `stats` is scipy.stats (imported earlier in the notebook); the previous
    # `s.stats` referred to a name that is never defined as a module.
    # With too few points the correlation is reported as NaN.
    r, p = stats.pearsonr(data['average distance'].values, data['lifetime'].values)\
           if len(data['lifetime'])>2 else (np.nan, np.nan)
    sns.regplot(x='average distance', y='lifetime', marker='+', fit_reg=True, data=data)
    # Labels now match the values (r and p were swapped in the title).
    plt.title(topic + '\nr = {:3.2f}, p = {:6.5f}'.format(r, p))

**Check empirical vs non-empirical sciences?**

## Lifetime vs Cavity weights

In [None]:
# Pick one topic's real network for the ad-hoc check in the next cell; the
# loops further down reassign both `topic` and `network`.
topic = 'biochemistry'
network = networks[topic]

In [None]:
# Sanity check on a two-node subgraph: mean of the 'weight' attribute over
# whatever edges exist between 'Carbon' and 'Alcohol' in the selected network.
sub = network.graph.subgraph(['Carbon','Alcohol'])
np.mean([sub.edges[n1,n2]['weight'] for n1,n2 in sub.edges])

In [None]:
# For every barcode, average the 'weight' of all edges in the subgraph induced
# by its 'homology nodes' and store it in barcodes['mean edge weights'].
n_rows = len(barcodes.index)
for i, row in barcodes.iterrows():
    sys.stdout.write("\rindex: " + str(i+1) + '/' + str(n_rows))
    sys.stdout.flush()
    nodes = row['homology nodes']
    topic = row['topic']
    # Use the null replicate the barcode came from (the 'null' column); always
    # taking replicate 0 reads edge weights from the wrong graph for the rest.
    network = networks[topic] if row['type']=='real'\
              else null_targets[topic][int(row['null'])]
    subgraph = network.graph.subgraph(nodes)
    weights = [subgraph.edges[u,v]['weight'] for u,v in subgraph.edges]
    # No edges -> NaN, as before, but without numpy's empty-mean RuntimeWarning.
    barcodes.loc[i,'mean edge weights'] = np.mean(weights) if weights else np.nan
barcodes

In [None]:
# Per topic (real networks only, finite non-zero lifetimes, NaN rows dropped):
# regress lifetime on the mean edge weight and report Pearson's (r, p).
for topic in networks:
    plt.figure(figsize=(6,6))
    data = barcodes[(barcodes.topic==topic) & (barcodes.type=='real') &
                    (barcodes.lifetime!=np.inf) & (barcodes.lifetime!=0)].dropna()
    # `stats` is scipy.stats (imported earlier in the notebook); the previous
    # `s.stats` referred to a name that is never defined as a module.
    # With too few points report NaN rather than (0, 0), which would read as
    # a maximally significant p-value.
    r, p = stats.pearsonr(data['mean edge weights'].values, data['lifetime'].values)\
           if len(data['lifetime'])>2 else (np.nan, np.nan)
    sns.regplot(x='mean edge weights', y='lifetime', marker='+', fit_reg=True,
                data=data)
    plt.title(topic + ' ' + str((r,p)))

### Lifetime vs Mean weights of death simplex

In [None]:
# For every barcode, average the weights of the edges that exist between the
# nodes of its 'death simplex' (both directions of every node pair are tried).
# Results are collected in row order in `mean_weights`.
mean_weights = []
for i in range(len(barcodes.index)):
    row = barcodes.iloc[i]  # fetch the row once instead of once per field
    death_simplex = row['death simplex']
    topic = row['topic']
    network_type = row['type']
    pairs = [(n1,n2) for n2 in death_simplex
                     for n1 in death_simplex if n1!=n2]
    if network_type=='real':
        graph = networks[topic].graph
    elif network_type=='null_targets':
        graph = null_targets[topic][int(row['null'])].graph
    else:
        # Previously an unknown type silently reused `edges` from the
        # preceding row (or raised NameError on the first row).
        raise ValueError('unexpected network type: {!r}'.format(network_type))
    edges = [graph.get_edge_data(n1,n2) for n1,n2 in pairs]
    weights = [e['weight'] for e in edges if e]
    if not edges:
        mean_weight = 0          # simplex with fewer than two distinct nodes
    elif weights:
        mean_weight = np.mean(weights)
    else:
        mean_weight = np.nan     # same value np.mean([]) gave, without the warning
    mean_weights.append(mean_weight)

In [None]:
# Attach the death-simplex mean weights computed in the previous cell; the list
# was built in row order, so it lines up with the rows of `barcodes`.
barcodes['mean weights'] = mean_weights
barcodes

In [None]:
# Lifetime against death-simplex mean weight over all topics and network types.
# Infinite lifetimes are dropped rather than capped at the largest finite value.
plt.figure(figsize=(10,10))
data = barcodes[barcodes.lifetime!=np.inf].copy()
ax = sns.regplot(data=data, x='mean weights', y='lifetime', marker='.')
slope, intercept, r, p, stderr = stats.linregress(data['mean weights'],
                                                  data['lifetime'])
plt.title('r={:.4f}, p={:.4f}'.format(r, p))

In [None]:
# Same regression, one figure per topic (real and null rows together).
# Infinite lifetimes are dropped rather than capped at the largest finite value.
for topic in topics:
    plt.figure(figsize=(6,6))
    data = barcodes[(barcodes.topic==topic) & (barcodes.lifetime!=np.inf)].copy()
    sns.regplot(data=data, x='mean weights', y='lifetime',
                marker='+', fit_reg=True)
    slope, intercept, r, p, stderr = stats.linregress(data['mean weights'],
                                                      data['lifetime'])
    plt.title('{}\nr={:.4f}, p={:.4f}'.format(topic, r, p))