In [None]:
%load_ext autoreload

import sys, os
sys.path.insert(0, '../')
sys.path.insert(0, '../python_src/')


import numpy as np
import numpy.linalg as la
import scipy as sp
import scipy.io as spio

import matplotlib.pyplot as plt; plt.rcdefaults()
import matplotlib as mpl
import seaborn as sns
import matplotlib.colors as mcolors
import matplotlib.cm as cm
import matplotlib.image as mpimg


import itertools as it
import queue
import networkx as nx
import time

from skimage import filters as skifilters
from skimage import color as skicolor

import homology

# Computer Modern glyphs for $...$ mathtext in axis labels/titles.
mpl.rcParams['mathtext.fontset'] = 'cm'
# sns.set() re-applies its own defaults (context='notebook', font_scale=1, style='darkgrid'),
# so it must run FIRST; calling it after set_context() silently discarded the 'poster' context.
sns.set(color_codes=True, palette='deep')
sns.set_context('poster', font_scale=1.25)
# Inward ticks and heavier axis spines on top of the 'ticks' style.
sns.set_style('ticks', {'xtick.direction': 'in','ytick.direction': 'in', 'axes.linewidth': 2.0})

In [None]:
# Select ONE input by uncommenting it. Every downstream cell expects a 2-D array named `data`
# (one height value per pixel); nothing else in the notebook defines it.

# mat = spio.loadmat("../sample_data/Z17.mat")
# # data = mat['Z17'][2400:2475, 2400:2475][52:58, 3:8]
# data = mat['Z17'][1500:2000, 1500:2000]

# data = mpimg.imread('2D6185C9-1DD8-B71C-076FF6978423E103-large.jpg')[:,:,:3]
# data = mpimg.imread('20170105_100151.jpg')[:,:,:3]
# # data = mpimg.imread('2018-01-12.png')[:,:,:3]

# RGB images must be reduced to a single channel before use:
# data = skicolor.rgb2gray(data)
# data = skifilters.laplace(data)

# data = np.zeros([100, 100], float)
# data[50, 50] = -1

# print(np.min(data), np.max(data))

# Small hand-made test inputs:
# data = np.array([[10, 0, 10],
#                  [2, 1, 2],
#                  [3, 6, 3],
#                  [4, 5, 4]])

# data = np.array([[3,2,2],
#                  [3,1,1],
#                  [0,1,0]])

# data = np.array([[8, 3, 2],
#                  [7, 8, 1],
#                  [6, 8, 1],
#                  [5, 10, 0]])

# Fail loudly with an actionable message instead of a bare NameError on a fresh kernel.
if 'data' not in globals():
    raise NameError("`data` is not defined: uncomment one of the input sources at the top of this cell.")

print(data.shape)

# Grey-scale preview of the input, normalised to its own min/max.
norm = mpl.colors.Normalize(vmin=np.min(data),vmax=np.max(data))
smap = cm.ScalarMappable(norm=norm, cmap=plt.cm.Greys_r)

fig, ax = plt.subplots(figsize=(16,16))
im = ax.imshow(smap.to_rgba(data))

ax.axis('off')
plt.show()

In [None]:
%autoreload

print("Constructing complex")
comp = homology.construct_cubical_complex(data.shape, oriented=False)

# print(comp.facets)
# print(comp.dims)

homology.check_boundary_op(comp)

print("Constructing cofacets")
comp.construct_cofacets()

# print(comp.cofacets)

print("Constructing discrete gradient")

vertex_time = data.flatten()

vertex_order = homology.construct_vertex_filtration_order(comp, vertex_time, euclidean=False, positions=None)

# fig, ax = plt.subplots(figsize=(16,16))
# im = ax.imshow(vertex_order.reshape(data.shape), cmap=plt.cm.Greys_r)

# ax.axis('off')
# # plt.colorbar(im)
# plt.show()


V = homology.construct_discrete_gradient(comp, vertex_order)

# print(V)


palette = it.cycle(sns.color_palette("deep"))
# palette = it.cycle(['Blues_r', 'Greens_r', 'Purples_r', 'Oranges_r', 'RdPu_r'])

basin_color_map = {}

n = 0
for v in V:
    if v == V[v]:
        n += 1
        if comp.dims[v] == 0:
            basin_color_map[v] = next(palette)
            
        
print("Number Critical Cells:", n)

print("Reversing gradient")
coV = homology.reverse_discrete_gradient(V)

# print(coV)

print("Finding Insertion Times")

time_insert = homology.construct_time_of_insertion_map(comp, vertex_time, vertex_order)


print("Complete")

In [None]:
%autoreload

print("Simplifying Morse Complex")

V, coV = homology.simplify_morse_complex(3e-1, V, coV, comp, time_insert)

n = 0
for v in V:
    if v == V[v]:
#         print(v, comp.dims[v])
        n += 1


In [None]:
%autoreload

print("Finding basins")
basins = homology.find_basins(coV, comp.cofacets, comp.dims, 0)
# print(basins)

print("Calculating Morse complex")
mcomp = homology.construct_morse_complex(V, comp.facets, comp, oriented=False)

print("Checking Boundary Operator")
homology.check_boundary_op(comp)

mcomp.construct_cofacets()

# print(mcomp.facets)
# print(mcomp.dims)

# print(mcomp.cofacets)

print("Calculating Morse skeleton")
skeleton = homology.find_morse_skeleton(mcomp, comp, V, coV, 1)

print("Complete")

In [None]:
%autoreload

image = np.ones([data.shape[0]*data.shape[1], 3])

for i in basins:
    
    image[list(basins[i])] = basin_color_map[i]
    
# image = np.zeros([data.shape[0]*data.shape[1], 4])    
    

# norm = mpl.colors.Normalize(vmin=np.min(data), vmax=np.max(data))

# flat = data.flatten()

# for i in basins:
#     smap = cm.ScalarMappable(norm=norm, cmap=basin_color_map[i])
#     image[list(basins[i])] = smap.to_rgba(flat[list(basins[i])])
    
image[list(skeleton)] = (0.0, 0.0, 0.0)

for i in basins:
    image[i] = (1.0, 1.0, 1.0)
        
print(np.min(data), np.max(data))

fig, ax = plt.subplots(figsize=(16, 16))
ax.imshow(image.reshape((data.shape[0], data.shape[1], 3)))


ax.axis('off')
plt.show()

In [None]:
%autoreload

print("Calculating persistence pairs...")

filtration = homology.construct_filtration(mcomp, time_insert)

# weights = homology.get_morse_weights(mcomp, V, coV, comp.facets, comp.cofacets)
   
# print("Weights:", weights)

start = time.time()

(pairs, cell_index) = homology.compute_persistence(mcomp, filtration, show_zero=True)
# (pairs, cell_index, bcycles) = homology.compute_persistence(mcomp, filtration, show_zero=True, 
#                                                             birth_cycles=True, optimal_cycles=False)
# (pairs, cell_index, bcycles, ocycles) = homology.compute_persistence(mcomp, filtration, show_zero=False, 
#                                                             birth_cycles=True, optimal_cycles=True,
#                                                                     weights=weights, relative_cycles=True)
end = time.time()

print("Elapsed Time:", end - start)

# print("Pairs:", pairs)
# print("Birth Cycles:", bcycles)
# print("Death Cycles:", ocycles)

birth = [[] for i in range(mcomp.dim+1)]
death = [[] for i in range(mcomp.dim+1)]

for (i, j) in pairs:
    if j is None:
        continue
    if cell_index[i][0] == cell_index[j][0]:
        continue
    
    d = mcomp.dims[i]
    birth[d].append(cell_index[i][1])
    if j is not None:
        death[d].append(cell_index[j][1])
    else:
        death[d].append(cell_index[i][1])
        
print("Complete")

In [None]:
# Combined persistence diagram: H0 pairs in blue, H1 pairs in green.
# NOTE(review): H1 is drawn as (death, birth), i.e. mirrored below the diagonal, while H0
# is drawn as (birth, death); presumably intentional (superlevel-set view), confirm.
fig = plt.figure(figsize=(8,8))
ax1 = fig.add_subplot(1,1,1)

ax1.scatter(birth[0], death[0], marker='.', color='b')
ax1.scatter(death[1], birth[1], marker='.', color='g')

# Dashed diagonal spanning the full range of vertex heights.
t_min, t_max = np.min(vertex_time), np.max(vertex_time)
diagonal = np.linspace(t_min, t_max, 100)
ax1.plot(diagonal, diagonal, 'k--')

ax1.set_title(r"$d={}$".format(0))
ax1.set_xlabel(r"Birth [height]")
ax1.set_ylabel(r"Death [height]")

plt.tight_layout()
plt.show()



In [None]:
# Persistence (death - birth) against birth height on a log y-axis: H0 blue, H1 green.
fig = plt.figure(figsize=(8,8))
ax1 = fig.add_subplot(1,1,1)
ax1.set_yscale('log')

for dim, colour in ((0, 'b'), (1, 'g')):
    lifetimes = np.array(death[dim]) - np.array(birth[dim])
    ax1.scatter(birth[dim], lifetimes, marker='.', color=colour)

ax1.set_xlabel(r"Birth [height]")
ax1.set_ylabel(r"Persistence [height]")
ax1.set_ylim(1e-6, 1e3)

plt.tight_layout()
plt.show()

In [None]:
# Reference cut drawn on the histogram, in log10(persistence) units (10**-1.9 ≈ 0.0126).
LOG_PERSISTENCE_CUT = -1.9

fig, ax = plt.subplots(1, 1, figsize=(8,8))

lifetimes = np.concatenate((np.array(death[0]) - np.array(birth[0]),
                            np.array(death[1]) - np.array(birth[1])))
# Pairs can survive the earlier filter with zero persistence (presumably when heights tie);
# log10(0) = -inf would make the histogram range non-finite, so keep only positive values.
h = np.log10(lifetimes[lifetimes > 0])

sns.distplot(h, ax=ax, kde=True, norm_hist=True)

ax.set_ylim(1e-6, 5e0)
ax.set_yscale('log')

ax.vlines(LOG_PERSISTENCE_CUT, 1e-6, 5e0)

plt.show()

In [None]:
def _log_joint_scatter(x_values, y_values, ylabel):
    """Scatter log10(y_values) against log10(x_values) with marginal histograms."""
    # Keyword x=/y= works on old and new seaborn; positional args are a TypeError in >=0.12.
    g = sns.JointGrid(x=np.log10(x_values), y=np.log10(y_values))
    g.plot_joint(plt.scatter, marker='.', s=8, color='b')
    g.plot_marginals(sns.distplot, kde=False)

    ax = g.ax_joint
    ax.set_xlabel(r"$\log_{10}$Persistence [height]")
    ax.set_ylabel(ylabel)
    plt.show()


# --- H0: size of the segment attached to each finite pair ---
persistence = []
area = []

for (i, j) in pairs:
    if j is None or mcomp.dims[i] != 0 or cell_index[j][0] - cell_index[i][0] == 0:
        continue
    p = cell_index[j][1] - cell_index[i][1]
    if p <= 0:
        continue  # log10 is undefined at zero persistence

    persistence.append(p)
    segment = homology.find_segment(i, j, basins, cell_index, mcomp)
    area.append(len(segment))

_log_joint_scatter(persistence, area, r'$\log_{10}$Area [pixels$^2$]')

# --- H1: number of pixels covered by each death chain ---
# NOTE(review): the count is taken over the whole death chain (no get_boundary_cycle),
# so it measures enclosed size rather than cycle length despite the variable name.
persistence = []
length = []

for (i, j) in pairs:
    if j is None or mcomp.dims[i] != 1 or cell_index[j][0] - cell_index[i][0] == 0:
        continue
    p = cell_index[j][1] - cell_index[i][1]
    if p <= 0:
        continue  # checked before the expensive expansion below

    dcells = homology.expand_death_cycle(i, j, cell_index, mcomp)
    cells = homology.convert_morse_to_complex(dcells, mcomp, comp, V, coV)
    verts = homology.get_vertices(cells, comp)

    persistence.append(p)
    length.append(len(verts))

_log_joint_scatter(persistence, length, r'$\log_{10}$Area [pixels$^2$]')

In [None]:
%autoreload

def plot_features(N_features, sorted_pairs):
    
#     palette = it.cycle(sns.color_palette("deep"))

    palette = it.cycle(['Blues_r', 'Greens_r', 'Purples_r', 'Reds_r'])

    norm = mpl.colors.Normalize(vmin=np.min(data), vmax=np.max(data))
    smap = cm.ScalarMappable(norm=norm, cmap=plt.cm.Greys_r)

    image = smap.to_rgba(data).reshape((data.shape[0]*data.shape[1], 4))

    flat = data.flatten()


    cycle_skeleton = set()
    minima = set()

    last_cycle = set()
    last_segment = set()
    
    last_dcells = set()
    last_cycle_j = 0
        
    for n in range(N_features):
        
        

        (p, (i, j)) = sorted_pairs[n]
        
#         if j is None:
#             continue

#         print(sorted_pairs[n])

        if mcomp.dims[i] == 0:

            segment = list(homology.find_segment(i, j, basins, cell_index, mcomp))
#             color = list(next(palette))
#             color.append(1)
#             image[list(segment)] = color

            
            smap = cm.ScalarMappable(norm=norm, cmap=next(palette))
            image[segment] = smap.to_rgba(flat[segment])
        
#             for v in segment:
#                 image[v] = smap.to_rgba(flat[v])

#             color = next(palette)
#             image[list(basins[i])] = color

            minima.add(i)
            last_segment = segment
            
            last_basin = i

        elif mcomp.dims[i] == 1:
            
            dcells = homology.expand_death_cycle(i, j, cell_index, mcomp)
            dbound = homology.get_boundary_cycle(dcells, mcomp)
            dcycle = homology.convert_morse_to_complex(dbound, mcomp, comp, V, coV)
            cycle = homology.get_vertices(dcycle, comp)

            cycle_skeleton.update(cycle)

            last_dcells = dcells
            last_cycle = cycle            
            last_cycle_j = j

    print(sorted_pairs[N_features-1], mcomp.dims[sorted_pairs[N_features-1][1][0]])
            
    if mcomp.dims[sorted_pairs[N_features-1][1][0]] == 1:
        image[list(cycle_skeleton)] = (0.0, 0.0, 0.0, 1.0)
                
        cells = homology.convert_morse_to_complex(last_dcells, mcomp, comp, V, coV)
        verts = homology.get_vertices(cells, comp)
        
#         h = "c51b8a"
#         color = list(int(h[i:i+2], 16) / 256.0 for i in (0, 2 ,4))
#         color.append(1)
#         image[list(verts)] = color
        
        
        h = "ff7f00"
        color = list(int(h[i:i+2], 16) / 256.0 for i in (0, 2 ,4))
        color.append(1)
        image[list(last_cycle)] = color
        
        
#         h = "f03b20"
#         color = tuple(int(h[i:i+2], 16) / 256.0 for i in (0, 2 ,4))
#         image[list(set(verts) & set(last_cycle))] = color
                        
    else:
        

        smap = cm.ScalarMappable(norm=norm, cmap='Oranges_r')
#         for v in last_segment:
#             image[v] = smap.to_rgba(flat[v])
        
        image[segment] = smap.to_rgba(flat[last_segment])
        
#         h = "ff7f00"
#         color = list(int(h[i:i+2], 16) / 256.0 for i in (0, 2 ,4))
#         color.append(1)
#         image[list(last_segment)] = color

        image[list(cycle_skeleton)] = (0.0, 0.0, 0.0, 1.0)
    
   
    # image[list(minima)] = (1.0, 1.0, 1.0)

    
    
    fig, ax = plt.subplots(figsize=(16, 16))
    
    ax.imshow(image.reshape((data.shape[0], data.shape[1], 4)))


    ax.axis('off')
    plt.show()
    
persistence = []
for (i, j) in pairs:
    if j is not None:
        persistence.append((cell_index[j][1] - cell_index[i][1], (i, j)))

persistence = sorted(persistence, reverse=True)

print(len(persistence))

for n in range(len(persistence)):
    
    print(n)
    
    if persistence[n][0] < 0.1 or n >= 20:
        break
    
    plot_features(n+1, persistence)
        

In [None]:
%autoreload

palette = it.cycle(sns.color_palette("deep"))
# palette = it.cycle(sns.color_palette("Blues"))

image = np.ones([data.shape[0]*data.shape[1], 3])

for i in basins:

    color = next(palette)
    image[list(basins[i])] = color

    
print("Calculating Morse skeleton")
skeleton = homology.find_morse_skeleton(mcomp, V, coV, comp.facets, comp.cofacets, comp.dims, 1)
image[list(skeleton)] = (0.0, 0.0, 0.0)

for i in basins:
    image[i] = (1.0, 1.0, 1.0)
    
    
dcycle = homology.expand_death_cycle(1440, 2034, cell_index, mcomp)
        
print(dcycle)
   
face = homology.convert_morse_to_complex({2034}, mcomp, comp, V, coV)
print(face)

verts = set()
for i in face:
    for j in comp.facets[i]:
        verts.update(comp.facets[j])
        
image[list(verts)] = (0.5, 0.5, 0.5)
    
    
cycle = homology.convert_morse_to_complex(dcycle, mcomp, comp, V, coV)
verts = set()
for i in cycle:
    verts.update(comp.facets[i])
image[list(verts)] = (1.0, 1.0, 1.0)





# image[np.where(heights > cell_index[1440][1])[0]] = (0.5, 0.5, 0.5)
    
fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(image.reshape((data.shape[0], data.shape[1], 3)))


ax.axis('off')
plt.show()

In [None]:
%autoreload

def plot_features(N_features, sorted_pairs):
    
    palette = it.cycle(sns.color_palette("deep"))
    image = np.ones([data.shape[0]*data.shape[1], 3])
#     image = 0.5*np.ones([data.shape[0]*data.shape[1], 3])



    cycle_skeleton = set()
    minima = set()

    
    last_basin = 0
    
    last_cycle = set()
    last_cycle_j = 0
    last_dcells = set()
        
    for n in range(N_features):
        
        

        (p, (i, j)) = sorted_pairs[n]
        
#         if j is None:
#             continue

#         print(i, j)

#         print(sorted_pairs[n])

        if mcomp.dims[i] == 0:

            color = next(palette)
            image[list(basins[i])] = color

            minima.add(i)
            
            last_basin = i

        elif mcomp.dims[i] == 1:
            
            dcells = homology.expand_death_cycle(i, j, cell_index, mcomp)
            dbound = homology.get_boundary_cycle(dcells, mcomp)
            dcycle = homology.convert_morse_to_complex(dbound, mcomp, comp, V, coV)
            cycle = homology.get_vertices(dcycle, comp)

            cycle_skeleton.update(cycle)

            last_dcells = dcells
            last_cycle = cycle            
            last_cycle_j = j
            
            
    
    
    
    print(sorted_pairs[N_features-1], mcomp.dims[sorted_pairs[N_features-1][1][0]])
    
    if mcomp.dims[sorted_pairs[N_features-1][1][0]] == 1:
        image[list(cycle_skeleton)] = (0.0, 0.0, 0.0)
        
        cells = homology.convert_morse_to_complex(last_dcells, mcomp, comp, V, coV)
        verts = homology.get_vertices(cells, comp)
        
        h = "c51b8a"
        color = tuple(int(h[i:i+2], 16) / 256.0 for i in (0, 2 ,4))
        image[list(verts)] = color
        
        
        h = "ff7f00"
        color = tuple(int(h[i:i+2], 16) / 256.0 for i in (0, 2 ,4))
        
        image[list(last_cycle)] = color
                        
    else:

        h = "ff7f00"
        color = tuple(int(h[i:i+2], 16) / 256.0 for i in (0, 2 ,4))

        image[list(basins[last_basin])] = color
        image[list(cycle_skeleton)] = (0.0, 0.0, 0.0)
        

    image[list(minima)] = (1.0, 1.0, 1.0)

    image[np.where(heights > born[N_features-1][0])[0]] = (0.5, 0.5, 0.5)

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(image.reshape((data.shape[0], data.shape[1], 3)))


    ax.axis('off')
    plt.show()

born = []
for (i, j) in pairs:
    born.append((cell_index[i][1], (i, j)))
                
born = sorted(born)

print(born)

print(len(born))

for n in range(len(born)):
    plot_features(n+1, born)
    
    
# persistence = []
# for (i, j) in pairs:
#     if j is not None:
#         persistence.append((cell_index[j][1] - cell_index[i][1], (i, j)))

# persistence = sorted(persistence, reverse=True)

# print(len(persistence))

# for n in range(2):
#     plot_features(n+1, persistence)
        