In [None]:
%load_ext autoreload

import sys, os
sys.path.insert(0, '../')
sys.path.insert(0, '../python_src/')


import numpy as np
import numpy.linalg as la
import scipy as sp
import scipy.io as spio

import matplotlib.pyplot as plt; plt.rcdefaults()
import matplotlib as mpl
import seaborn as sns
import matplotlib.colors as mcolors


import itertools as it
import queue
import networkx as nx
import skimage.morphology as morph

import homology

# Global plot styling. Order matters:
#  * sns.set() resets the plotting context to 'notebook', so it must run
#    BEFORE set_context('poster'); otherwise the poster scaling is discarded.
#  * 'axes.linewidth' is a seaborn *context* key; set_style() silently drops
#    non-style keys from its rc dict, so it is passed to set_context instead.
#  * mathtext.fontset is applied last so no seaborn call can override it.
sns.set(color_codes=True, palette='deep')
sns.set_context('poster', font_scale=1.25, rc={'axes.linewidth': 2.0})
sns.set_style('ticks', {'xtick.direction': 'in', 'ytick.direction': 'in'})
mpl.rcParams['mathtext.fontset'] = 'cm'

In [None]:
# Load a 2-D height field from a .mat file and analyse a rectangular crop.
DATA_FILE = "../sample_data/Z17.mat"
DATA_KEY = 'Z17'
CROP_ROWS = slice(1500, 2000)
CROP_COLS = slice(1500, 2000)

# Alternative inputs used during development (swap in as needed):
#   tiny debug crop:   mat['Z17'][2400:2475, 2400:2475][52:58, 3:8]
#   thresholded map:   spio.loadmat("../sample_data/thresholded/Creases15.mat")
#                      ['ridges'][1750:2000, 1750:2000].astype(np.int8)
#                      (optionally + the matching 'valleys' array)
#   hand-made height fields:
#     np.array([[7, 6, 5, 4],
#               [8, 5, 4, 3],
#               [9, 4, 3, 2],
#               [0, 3, 2, 1]])
#     np.array([[3, 2, 2],
#               [3, 1, 1],
#               [0, 1, 0]])
#     np.array([[8, 3, 2],
#               [7, 8, 1],
#               [6, 8, 1],
#               [5, 10, 0]])

mat = spio.loadmat(DATA_FILE)
data = mat[DATA_KEY][CROP_ROWS, CROP_COLS]

print(data.shape)

fig, ax = plt.subplots(figsize=(16, 16))
im = ax.imshow(data, cmap=plt.cm.Blues_r)
ax.axis('off')
plt.show()

In [None]:
# %autoreload

# Build a cubical complex over the data grid, compute a discrete gradient from
# the pixel values, reverse it, and group cells into basins. The objects made
# here (comp, heights, V, coV, basins) are reused by every later cell.
print("Constructing complex")
comp = homology.construct_cubical_complex(data.shape)

print("Constructing cofacets")
# Called only for its effect on `comp` (the result is not assigned);
# presumably it populates comp.cofacets, which find_basins reads below.
comp.construct_cofacets()

print("Constructing discrete gradient")

# One value per pixel in row-major (C) order. The commented alternative
# replaces raw values by their ranks, making every height distinct.
# NOTE(review): confirm whether construct_discrete_gradient needs tie-free heights.
# heights = np.argsort(np.argsort(data.flatten()))
heights = data.flatten()

# V is indexed as V[v] and compared with v (see the commented check below);
# presumably cells with V[v] == v are the unpaired/critical cells — TODO confirm.
V = homology.construct_discrete_gradient(comp, heights)

# print(V)

# for v in V:
#     if v == V[v]:
#         print(v, comp.dims[v])
   
print("Reversing gradient")
coV = homology.reverse_discrete_gradient(V)

print("Finding basins")
# Later cells iterate basins' keys and take len(basins[k]) as a pixel area;
# the trailing 0 presumably selects the cell dimension — TODO confirm in homology.
basins = homology.find_basins(coV, comp.cofacets, comp.dims, 0)


print("Complete")

In [None]:

# Paint every basin in its own colour (cycling through the seaborn "deep"
# palette) on a white background, then draw the Morse skeleton in black.
palette = it.cycle(sns.color_palette("deep"))
image = np.ones([data.shape[0]*data.shape[1], 3])

for key, color in zip(basins, palette):
    image[list(basins[key])] = color

print("Calculating Morse skeleton")
skeleton = homology.find_morse_skeleton(V, coV, comp.facets, comp.cofacets, comp.dims, 1)
black = (0.0, 0.0, 0.0)
image[list(skeleton)] = black

print(np.min(data), np.max(data))

rgb = image.reshape((data.shape[0], data.shape[1], 3))
fig, ax = plt.subplots(figsize=(16, 16))
ax.imshow(rgb)
ax.axis('off')
plt.show()

In [None]:
print("Calculating Morse complex")
mcomp = homology.construct_morse_complex(V, comp.facets, comp)

mcomp.construct_cofacets()

print("Calculating persistence pairs...")

filtration = homology.construct_filtration(comp.facets, comp.dims, mcomp, heights)

(pairs, cell_index) = homology.compute_persistence_pairs(mcomp, filtration, show_zero=True)

# Collect birth/death values of the finite pairs, bucketed by the dimension of
# the birth cell. Skipped:
#  * pairs with no death cell (j is None) — essential classes are NOT plotted;
#  * pairs whose two cells share cell_index[...][0]
#    (NOTE(review): presumably zero-persistence pairs — confirm).
# The former `else` branch recording death = birth was unreachable because of
# the `continue` below, so it has been removed.
birth = [[] for _ in range(mcomp.dim + 1)]
death = [[] for _ in range(mcomp.dim + 1)]

for (i, j) in pairs:
    if j is None or cell_index[i][0] == cell_index[j][0]:
        continue

    d = mcomp.dims[i]
    birth[d].append(cell_index[i][1])
    death[d].append(cell_index[j][1])

print("Complete")

In [None]:
# Persistence diagram (birth vs death height) for each dimension of the Morse
# complex. Looping over len(birth) (= mcomp.dim + 1) instead of a fixed 3 keeps
# the cell from raising IndexError when the complex has fewer dimensions.
h_lo, h_hi = np.min(heights), np.max(heights)

for d in range(len(birth)):

    fig, ax1 = plt.subplots(figsize=(8, 8))

    ax1.scatter(birth[d], death[d], marker='.')

    # Reference diagonal birth == death.
    ax1.plot([h_lo, h_hi], [h_lo, h_hi], 'k--')

    ax1.set_title(r"$d={}$".format(d))

    ax1.set_xlabel(r"Birth [height]")
    ax1.set_ylabel(r"Death [height]")

    plt.tight_layout()

    plt.show()


In [None]:
# Persistence (death - birth) plotted against birth height on a logarithmic
# persistence axis, for the 0- and 1-dimensional pairs.
for dim in (0, 1):
    lifetimes = np.subtract(death[dim], birth[dim])

    fig, ax1 = plt.subplots(figsize=(8, 8))
    ax1.set_yscale('log')
    ax1.scatter(birth[dim], lifetimes, marker='.')

    ax1.set_title(r"$d={}$".format(dim))
    ax1.set_xlabel(r"Birth [height]")
    ax1.set_ylabel(r"Persistence [height]")

    plt.tight_layout()
    plt.show()

In [None]:
# For every finite 0-dimensional pair, record its persistence and the total
# size of the basins reached by homology.level_bfs up to the death height,
# then show log-persistence vs log-area with marginal histograms.
persistence = []
area = []

for (i, j) in pairs:
    if j is None or mcomp.dims[i] != 0:
        continue

    persistence.append(cell_index[j][1] - cell_index[i][1])
    seen = homology.level_bfs(i, cell_index[j][1], mcomp.facets, mcomp.cofacets, cell_index)
    area.append(sum(len(basins[k]) for k in seen))

persistence_arr = np.asarray(persistence, dtype=float)
area_arr = np.asarray(area, dtype=float)

# log10(0) is -inf, which cannot be placed on the axes and breaks the
# histogram binning; drop such points before taking logs.
keep = (persistence_arr > 0) & (area_arr > 0)
log_persistence = np.log10(persistence_arr[keep])
log_area = np.log10(area_arr[keep])

# x/y are passed by keyword: they are keyword-only in seaborn >= 0.12.
# Marginals are drawn with matplotlib directly instead of the deprecated
# sns.distplot, which keeps this cell independent of the seaborn version.
g = sns.JointGrid(x=log_persistence, y=log_area)
g.ax_joint.scatter(log_persistence, log_area, marker='.', s=8, color='b')
g.ax_marg_x.hist(log_persistence, bins='auto')
g.ax_marg_y.hist(log_area, bins='auto', orientation='horizontal')

ax = g.ax_joint
ax.set_xlabel(r"$\log_{10}$Persistence [height]")
ax.set_ylabel(r'$\log_{10}$Area [pixels$^2$]')


plt.show()

In [None]:
# %autoreload

# (persistence, (birth_cell, death_cell)) for every finite pair, most
# persistent first. Essential classes (no death cell) are left out.
persistence = sorted(
    (
        (cell_index[j][1] - cell_index[i][1], (i, j))
        for (i, j) in pairs
        if j is not None
    ),
    reverse=True,
)

print(len(persistence))

In [None]:
# %autoreload

# Colour the basins swept out (via homology.level_bfs up to the death height)
# by the N_TOP most persistent 0-dimensional classes, on a grey background.
# `persistence` is sorted by decreasing persistence, so the first N_TOP
# dimension-0 entries are the most persistent ones. islice stops cleanly when
# fewer than N_TOP such entries exist (the old while-loop raised IndexError).
N_TOP = 10

palette = it.cycle(sns.color_palette("deep"))
image = np.full([data.shape[0]*data.shape[1], 3], 0.5)

dim0_entries = (entry for entry in persistence if mcomp.dims[entry[1][0]] == 0)

for p, (i, j) in it.islice(dim0_entries, N_TOP):
    print(p)

    seen = homology.level_bfs(i, cell_index[j][1], mcomp.facets, mcomp.cofacets, cell_index)

    color = next(palette)
    # Distinct loop name so the pair index `i` is not shadowed.
    for k in seen:
        image[list(basins[k])] = color

fig, ax = plt.subplots(figsize=(16, 16))
ax.imshow(image.reshape((data.shape[0], data.shape[1], 3)))


ax.axis('off')
plt.show()
        
        

In [None]:
# Leftover scratch check: max() of an empty set raises ValueError, which
# aborted Restart & Run All here. With a default the check runs cleanly (prints None).
print(max(set(), default=None))