In [None]:
%load_ext autoreload

import sys, os
sys.path.insert(0, '../')
sys.path.insert(0, '../python_src/')


import numpy as np
import numpy.linalg as la
import scipy as sp
import scipy.io as spio

import matplotlib.pyplot as plt; plt.rcdefaults()
import matplotlib as mpl
import seaborn as sns
import matplotlib.colors as mcolors


import itertools as it
import queue
import networkx as nx
import skimage.morphology as morph

import homology

# Global plotting configuration.
# sns.set() re-applies seaborn's default context and style, so it has to run
# first; calling it after set_context() would silently discard the 'poster'
# context requested below.
sns.set(color_codes=True, palette='deep')
sns.set_context('poster', font_scale=1.25)
sns.set_style('ticks', {'xtick.direction': 'in','ytick.direction': 'in', 'axes.linewidth': 2.0})
# Render math text with the Computer Modern font set.
mpl.rcParams['mathtext.fontset'] = 'cm'

In [None]:
# Load the sample array and crop the window analysed by the rest of the notebook.
# `mat` must be loaded in this cell: the crop below reads it, and on a fresh
# kernel no other cell defines it.
mat = spio.loadmat("../sample_data/Z17.mat")
data = mat['Z17'][1975:2000, 1975:2000]
# Alternative crop of the same array:
# data = mat['Z17'][2400:2475, 2400:2475][52:58, 3:8]

# Alternative input file:
# mat = spio.loadmat("../sample_data/thresholded/Creases15.mat")
# data = mat['ridges'][1750:2000, 1750:2000].astype(np.int8)
# data = mat['ridges'][1000:2000, 1000:2000].astype(np.int8)+mat['valleys'][1000:2000, 1000:2000].astype(np.int8)
# data = mat['ridges'].astype(np.int8)

# print(np.min(data), np.max(data))

# Small hand-written arrays, handy for checking results by eye:
# data = np.array([[10, 0, 10],
#                 [2, 1, 2],
#                 [3, 6, 3],
#                 [4, 5, 4]])

# data = np.array([[3,2,2],
#                 [3,1,1],
#                 [0,1,0]])

# data = np.array([[8, 3, 2],
#                 [7, 8, 1],
#                 [6, 8, 1],
#                 [5, 10, 0]])

print(data.shape)


# Show the cropped array as an image (axes hidden).
fig, ax = plt.subplots(figsize=(16,16))
im = ax.imshow(data, cmap=plt.cm.Blues_r)
ax.axis('off')
# plt.colorbar(im)
plt.show()

In [None]:
%autoreload

# Build the cubical complex for the pixel grid of `data`, derive a discrete
# gradient from the pixel ranks, reverse it, and collect the basins.
# The results (comp, heights, V, coV, basins) are read by the cells below.

print("Constructing complex")
comp = homology.construct_cubical_complex(data.shape, oriented=True)

# print(comp.facets)
# print(comp.dims)

homology.check_boundary_op(comp)

print("Constructing cofacets")
comp.construct_cofacets()

# print(comp.cofacets)

print("Constructing discrete gradient")

# Double argsort turns each pixel value into its rank in the flattened array,
# so every pixel gets a distinct integer height.
heights = np.argsort(np.argsort(data.flatten()))
# heights = data.flatten()

V = homology.construct_discrete_gradient(comp, heights)

# print(V)

# for v in V:
#     if v == V[v]:
#         print(v, comp.dims[v])

print("Reversing gradient")
coV = homology.reverse_discrete_gradient(V)

# print(coV)

print("Finding basins")
# NOTE(review): the trailing 0 presumably selects dimension-0 cells (minima) —
# confirm against homology.find_basins. Later cells use each basins[i] as a
# collection of flat pixel indices.
basins = homology.find_basins(coV, comp.cofacets, comp.dims, 0)
# NOTE(review): `valid` is not read by any cell visible here — confirm it is still needed.
valid = False
# print(basins)

print("Complete")

In [None]:
%autoreload
print("Calculating Morse complex")
mcomp = homology.construct_morse_complex(V, comp.facets, comp, oriented=True)

print("Checking Boundary Operator")
homology.check_boundary_op(comp)

mcomp.construct_cofacets()

# print(mcomp.facets)
# print(mcomp.dims)

# print(mcomp.cofacets)

print("Complete")

In [None]:
%autoreload

palette = it.cycle(sns.color_palette("deep"))
# palette = it.cycle(sns.color_palette("Blues"))

image = np.ones([data.shape[0]*data.shape[1], 3])

for i in basins:

    color = next(palette)
    image[list(basins[i])] = color

    
print("Calculating Morse skeleton")
skeleton = homology.find_morse_skeleton(mcomp, V, coV, comp.facets, comp.cofacets, comp.dims, 1)
image[list(skeleton)] = (0.0, 0.0, 0.0)

# for i in basins:
#     image[i] = (1.0, 1.0, 1.0)
        
print(np.min(data), np.max(data))
        

fig, ax = plt.subplots(figsize=(16, 16))
ax.imshow(image.reshape((data.shape[0], data.shape[1], 3)))


ax.axis('off')
plt.show()

In [None]:
%autoreload

print("Calculating persistence pairs...")

heights = data.flatten()
filtration = homology.construct_filtration(comp.facets, comp.dims, mcomp, heights)

# # (pairs, cell_index) = homology.compute_persistence(mcomp, filtration, show_zero=True)
(pairs, cell_index, bcycles, ocycles) = homology.compute_persistence(mcomp, filtration, show_zero=False, 
                                                            birth_cycles=True, optimal_cycles=True)

print(pairs)
print(bcycles)

birth = [[] for i in range(mcomp.dim+1)]
death = [[] for i in range(mcomp.dim+1)]

for (i, j) in pairs:
#     if j is None:
#         continue
#     if cell_index[i][0] == cell_index[j][0]:
#         continue
    
    d = mcomp.dims[i]
    birth[d].append(cell_index[i][1])
    if j is not None:
        death[d].append(cell_index[j][1])
    else:
        death[d].append(cell_index[i][1])
        
print("Complete")

In [None]:
# Persistence diagram (birth vs. death) for every dimension that has a
# birth/death list. Looping over len(birth) keeps this cell in step with the
# lists built from mcomp.dim instead of assuming exactly three dimensions.

# Diagonal birth == death across the full range of heights; it is the same for
# every panel, so compute it once.
diagonal = np.linspace(np.min(heights), np.max(heights), 100)

for d in range(len(birth)):

    fig = plt.figure(figsize=(8,8))

    ax1 = fig.add_subplot(1,1,1)

    ax1.scatter(birth[d], death[d], marker='.')

    ax1.plot(diagonal, diagonal, 'k--')

    ax1.set_title(r"$d={}$".format(d))

    ax1.set_xlabel(r"Birth [height]")
    ax1.set_ylabel(r"Death [height]")

    plt.tight_layout()

    plt.show()


In [None]:
# Persistence (death - birth) against birth, one panel with a log-scaled
# y axis for each of the dimensions 0 and 1.
for d in (0, 1):
    fig, ax1 = plt.subplots(figsize=(8, 8))
    ax1.set_yscale('log')

    lifetimes = np.subtract(death[d], birth[d])
    ax1.scatter(birth[d], lifetimes, marker='.')

    ax1.set_title(r"$d={}$".format(d))
    ax1.set_xlabel(r"Birth [height]")
    ax1.set_ylabel(r"Persistence [height]")

    plt.tight_layout()
    plt.show()

In [None]:
def _log_joint_plot(x_values, y_values, x_label, y_label):
    """Scatter log10(x_values) against log10(y_values) with marginal histograms.

    Parameters:
        x_values, y_values: equally long sequences of positive numbers.
        x_label, y_label: axis labels for the joint axes.

    Returns:
        The seaborn JointGrid that was drawn.
    """
    # x and y are passed by keyword: current seaborn releases accept them as
    # keyword-only arguments (the first positional argument is `data`), and
    # older releases accept the keyword form as well.
    grid = sns.JointGrid(x=np.log10(x_values), y=np.log10(y_values))

    # grid.plot_joint(plt.hexbin, cmap="Blues", gridsize=16, bins='log')
    grid.plot_joint(plt.scatter, marker='.', s=8, color='b')

    # NOTE(review): sns.distplot is deprecated in recent seaborn releases;
    # consider sns.histplot once the minimum seaborn version allows it.
    grid.plot_marginals(sns.distplot, kde=False)

    grid.ax_joint.set_xlabel(x_label)
    grid.ax_joint.set_ylabel(y_label)
    return grid


# --- Dimension 0: persistence of each paired minimum vs. size of its segment ---
persistence = []
area = []

for (i, j) in pairs:
    # Skip pairs without a death cell and pairs not born in dimension 0.
    if j is None or mcomp.dims[i] != 0:
        continue

    persistence.append(cell_index[j][1] - cell_index[i][1])
    segment = homology.find_segment(i, j, basins, cell_index, mcomp)
    area.append(len(segment))

g = _log_joint_plot(persistence, area,
                    r"$\log_{10}$Persistence [height]",
                    r'$\log_{10}$Area [pixels$^2$]')
ax = g.ax_joint

plt.show()

# --- Dimension 1: persistence of each paired cycle vs. length of the cycle ---
persistence = []
length = []

for (i, j) in pairs:
    # Skip pairs without a death cell and pairs not born in dimension 1.
    if j is None or mcomp.dims[i] != 1:
        continue

    cycle = homology.find_cycle(bcycles[i], mcomp, V, coV, comp.facets, comp.cofacets)

    persistence.append(cell_index[j][1] - cell_index[i][1])
    length.append(len(cycle))

g = _log_joint_plot(persistence, length,
                    r"$\log_{10}$Persistence [height]",
                    r'$\log_{10}$Length [pixels]')
ax = g.ax_joint

plt.show()

In [None]:
%autoreload

persistence = []
for (i, j) in pairs:
    if j is not None:
        
        persistence.append((cell_index[j][1] - cell_index[i][1], (i, j)))

persistence = sorted(persistence, reverse=True)

print(len(persistence))


palette = it.cycle(sns.color_palette("deep"))
image = np.ones([data.shape[0]*data.shape[1], 3])
image[:] = (0.5, 0.5, 0.5)


cycle_skeleton = set()
minima = set()
for n in range(43):

    (p, (i, j)) = persistence[n]
     
#     print(p, mcomp.dims[i])
        
    if mcomp.dims[i] == 0:
        
        color = next(palette)

        segment = homology.find_segment(i, j, basins, cell_index, mcomp)
        image[list(segment)] = color
        
        minima.add(i)
        
    elif mcomp.dims[i] == 1:
        cycle = homology.find_cycle(bcycles[i], mcomp, V, coV, comp.facets, comp.cofacets)
        cycle_skeleton.update(cycle)
        
image[list(cycle_skeleton)] = (0.0, 0.0, 0.0)
image[list(minima)] = (1.0, 1.0, 1.0)

fig, ax = plt.subplots(figsize=(16, 16))
ax.imshow(image.reshape((data.shape[0], data.shape[1], 3)))


ax.axis('off')
plt.show()
        
        

In [None]:
%autoreload

born = []
for (i, j) in pairs:
    born.append((cell_index[i][1], (i, j)))
                
born = sorted(born)

print(len(born))

print(born)

threshold = 43
palette = it.cycle(sns.color_palette("deep"))
image = np.ones([data.shape[0]*data.shape[1], 3])



cycle_skeleton = set()
minima = set()

last_cycle = set()
last_minimum = 0
for n in range(threshold):

    (b, (i, j)) = born[n]
     
    print(b, mcomp.dims[i])
        
    if mcomp.dims[i] == 0:
        
        color = next(palette)
        image[list(basins[i])] = color
        
        minima.add(i)
        last_minimum = i
        
    elif mcomp.dims[i] == 1:
        cycle = homology.find_cycle(bcycles[i], mcomp, V, coV, comp.facets, comp.cofacets)
        cycle_skeleton.update(cycle)
        
        last_cycle = cycle
        
image[list(cycle_skeleton)] = (0.0, 0.0, 0.0)


# image[last_minimum] = (1.0, 1.0, 1.0)
h = "fec44f"
image[list(last_cycle)] = tuple(int(h[i:i+2], 16) / 256.0 for i in (0, 2 ,4))

image[list(minima)] = (1.0, 1.0, 1.0)

image[np.where(heights > born[threshold-1][0])[0]] = (0.5, 0.5, 0.5)

fig, ax = plt.subplots(figsize=(16, 16))
ax.imshow(image.reshape((data.shape[0], data.shape[1], 3)))


ax.axis('off')
plt.show()
        
        

In [None]:
%autoreload

# Debugging scratch cell: inspects the cycle and the connections of one
# hard-coded cell.
# NOTE(review): 22889 and 1385 are hard-coded cell indices, presumably taken
# from a run on a different crop of the data — confirm they exist in the
# current complex before relying on this cell.

# The loop body runs at most once (it ends in `break`): it prints the first
# entry of bcycles and the cycle found for cell 22889.
for i in bcycles:
    print(mcomp.dims[i], i, bcycles[i])
    
    test_cycle = homology.find_cycle([22889], mcomp, V, coV, comp.facets, comp.cofacets)
    
#     for j in bcycles[i]:
#         print(mcomp.facets[j])

    print(test_cycle)
    
    break
    
print(mcomp.facets[22889])
    

# Start over with an empty set (discarding the cycle computed above) and fill
# it with the middle element of every triple yielded by find_connections.
test_cycle = set()
print("a")    
for (a, b, c) in homology.find_connections(22889, 1385, V, coV, comp.facets, comp.cofacets):
    test_cycle.add(b)
    print(a, b, c)
    
# # print("b")    
# # for (a, b, c) in homology.find_connections(10314, 319, V, coV, comp.facets, comp.cofacets):
# #     test_cycle.add(b)
# #     print(a, b, c)
    
# # print("c")    
# # for (a, b, c) in homology.find_connections(20118, 18, V, coV, comp.facets, comp.cofacets):
# #     test_cycle.add(b)
# #     print(a, b, c)
    
# # print("d")    
# # for (a, b, c) in homology.find_connections(20118, 319, V, coV, comp.facets, comp.cofacets):
# #     test_cycle.add(b)
# #     print(a, b, c)

In [None]:
# Show the basins keyed 18 and 319 on a mid-grey background, with the pixels
# collected in `test_cycle` in black and pixels 18 and 319 in white.
palette = it.cycle(sns.color_palette("deep"))
image = np.full((data.shape[0] * data.shape[1], 3), 0.5)

# Each basin takes the next palette colour.
for basin_key, color in zip((18, 319), palette):
    image[list(basins[basin_key])] = color

# image[list(skeleton)] = (0.0, 0.0, 0.0)
image[list(test_cycle)] = 0.0

image[[18, 319]] = 1.0

fig, ax = plt.subplots(figsize=(16, 16))
ax.imshow(image.reshape(data.shape[0], data.shape[1], 3))
ax.axis('off')
plt.show()
        