In [None]:
%load_ext autoreload

import sys, os
sys.path.insert(0, '../')
sys.path.insert(0, '../python_src/')

import numpy as np
import numpy.linalg as la
import scipy as sp
import scipy.io as spio

import matplotlib.pyplot as plt; plt.rcdefaults()
import matplotlib as mpl
import seaborn as sns
import matplotlib.colors as mcolors


import itertools as it
from collections import deque
import networkx as nx
import skimage.morphology as morph

import homology

# Plot styling.
# sns.set() (an alias of sns.set_theme) resets context, style and palette to
# its own defaults, so it must run BEFORE the specific overrides below —
# otherwise the 'poster' context is silently reset to 'notebook'.
sns.set(color_codes=True, palette='deep')
# Recent seaborn treats 'axes.linewidth' as a *context* key and drops it from
# set_style's rc, so it is also passed here to make the 2.0 width take effect.
sns.set_context('poster', font_scale=1.25, rc={'axes.linewidth': 2.0})
sns.set_style('ticks', {'xtick.direction': 'in', 'ytick.direction': 'in', 'axes.linewidth': 2.0})
# Computer Modern glyphs for $...$ math text in labels/titles.
mpl.rcParams['mathtext.fontset'] = 'cm'

In [None]:
# Load the crease data and keep the ridge map over rows/columns 1000-1999,
# cast to int8.
mat = spio.loadmat("../sample_data/Creases15.mat")
crop = np.s_[1000:2000, 1000:2000]
data = mat['ridges'][crop].astype(np.int8)
# Whole image instead: data = mat['ridges'].astype(np.int8)

fig, ax = plt.subplots(figsize=(16, 16))
ax.imshow(data, cmap=plt.cm.gray)
ax.set_axis_off()
plt.show()

In [None]:
def construct_pixel_graph(nrows, ncols, pos=False):
    """Build the 8-connected adjacency graph of an ``nrows x ncols`` pixel grid.

    Pixel (row ``i``, column ``j``) becomes node ``ncols*i + j`` (row-major
    numbering). Every node is linked to its horizontal, vertical and both
    diagonal neighbours.

    Parameters
    ----------
    nrows, ncols : int
        Grid dimensions.
    pos : bool, optional
        When True, also build a dict mapping each node to
        ``np.array([column, row])`` (x, y drawing coordinates).

    Returns
    -------
    networkx.Graph
        When ``pos`` is False.
    (networkx.Graph, dict)
        When ``pos`` is True.
    """
    def node(row, col):
        return ncols * row + col

    graph = nx.Graph()
    graph.add_nodes_from(node(r, c) for r in range(nrows) for c in range(ncols))

    coords = {}
    if pos:
        coords = {node(r, c): np.array([c, r])
                  for r in range(nrows) for c in range(ncols)}

    # The passes run in this fixed order (horizontal, vertical, "\" diagonal,
    # "/" diagonal); it determines the neighbour iteration order of graph[v].
    graph.add_edges_from((node(r, c), node(r, c + 1))
                         for r in range(nrows) for c in range(ncols - 1))
    graph.add_edges_from((node(r, c), node(r + 1, c))
                         for r in range(nrows - 1) for c in range(ncols))
    graph.add_edges_from((node(r, c), node(r + 1, c + 1))
                         for r in range(nrows - 1) for c in range(ncols - 1))
    graph.add_edges_from((node(r + 1, c), node(r, c + 1))
                         for r in range(nrows - 1) for c in range(ncols - 1))

    return (graph, coords) if pos else graph




In [None]:
def segment_graph(G, heights):
    """Watershed-style segmentation of a graph by flooding the level sets of `heights`.

    Distinct height values are processed in ascending order (``np.unique``).
    For each level:

    1. Nodes of the level that touch an already-labelled node seed a
       breadth-first search through the level.
    2. Each visited node takes the label of an already-labelled neighbour at a
       smaller BFS distance (or from a lower level). A node that sees two
       different basins becomes watershed (label 0).
    3. Connected pieces of the level that were never reached start new basins
       (labels 1, 2, ...). Because they are picked with ``set.pop``, the label
       numbering depends on set iteration order.

    After all levels, any basin node adjacent to a lower-numbered basin is
    turned into watershed.

    Parameters
    ----------
    G : networkx.Graph
        Nodes are assumed to be the integers ``0..G.order()-1``, because they
        index `heights` and the internal `basins` array. TODO: confirm for
        non-grid graphs.
    heights : numpy.ndarray
        Presumably a 1-D array of length ``G.order()`` giving the height of
        each node.

    Returns
    -------
    dict[int, set]
        Maps label -> set of node ids. Key 0 holds the watershed nodes; keys
        ``1..n`` hold the basins (some may be empty after the final pass).

    Notes
    -----
    Prints ``Level: h`` for every distinct height value.
    """
    # find level sets: height value -> node ids with that height, ascending by height
    level_sets = {}
    for h in np.unique(heights):
        level_sets[h] = list(np.where(heights==h)[0])
     
    # basin 0 is the watersheds; -1 marks nodes whose level has not been processed yet
    basins = np.full(G.order(), -1, int)
    current_label = 0
        
    # iterate through each level set (lowest height first)
    for h in level_sets:
        
        level = level_sets[h]
        
        print("Level:", h)

        # dist: BFS distance (within this level) from the already-labelled region
        dist = {}
        Q = deque()
        
        # find neighbors in lower level sets: these level nodes are the BFS seeds (distance 1)
        for vi in level:
            for nbr in G[vi]:
                if basins[nbr] >= 0:
                    if vi not in dist:
                        dist[vi] = 1
                        Q.append(vi)
                    
                    
        # level nodes not yet given any label (basin or watershed)
        unvisited = set(level)
                
        # breadth first search through connected part of level set
        while len(Q) > 0:
                        
            vi = Q.popleft()
            current_dist = dist[vi]
            
            
            for nbr in G[vi]:
                                
                # neighbor has already been assigned to a basin or watershed
                # and has a different distance (or no distance defined at all)
                # i.e. it lies in a lower level, or is closer to the seeds in this level
                if basins[nbr] >= 0 and (nbr not in dist or dist[nbr] < current_dist):
                    
                    
                    # neighbor in different basin
                    if basins[nbr] > 0:
                        
                        # current node not visited yet, or was assigned to watershed
                        if vi in unvisited or basins[vi] == 0:
                            # assign current node to neighbor's basin
                            basins[vi] = basins[nbr]
                            unvisited.discard(vi)
                        
                        # current node already assigned to basin
                        # and its basin is different from its neighbor's
                        elif basins[vi] != basins[nbr]:
                            # assign current node to watershed
                            basins[vi] = 0
                            unvisited.discard(vi)
                        
                    # neighbor is a watershed and current node is unvisited
                    elif vi in unvisited:
                        # assign current node to watershed
                        basins[vi] = 0
                        unvisited.discard(vi)
                
                # neighbor has not been visited and not been given a distance
                # (same level, unlabelled): extend the BFS one step further
                elif nbr in unvisited and nbr not in dist:
                    # append to queue
                    dist[nbr] = current_dist + 1
                    Q.append(nbr)
            
            
        # breadth first search through separate components corresponding to new minima
        # (parts of this level not connected to anything labelled get fresh labels)
        while len(unvisited) > 0:
                        
            vi = unvisited.pop()
            current_label += 1
            basins[vi] = current_label
            
            Q.append(vi)
            
            # flood the whole connected component (within this level) with the new label
            while len(Q) > 0:
                vj = Q.popleft()
                
                for nbr in G[vj]:
                    if nbr in unvisited:
                        Q.append(nbr)
                        basins[nbr] = current_label
                        unvisited.discard(nbr)
            
    
    # iterate through each pixel and check if it has neighbor of lower valued basin
    # add to watershed if it does
    # NOTE: basins is updated in place, so nodes set to 0 earlier in this loop
    # no longer count as basin neighbours for later nodes
    for vi in range(G.order()):
        if basins[vi] > 0:
            for nbr in G[vi]:
                if basins[nbr] != 0 and basins[nbr] < basins[vi]:
                    basins[vi] = 0
                    break
    
    # group node ids by final label: 0 = watershed, 1..current_label = basins
    segments = {i:set() for i in range(current_label+1)}
    for i in range(G.order()):
        segments[basins[i]].add(i)
        
    return segments


In [None]:
# 8-connected grid graph over the cropped image; pixel (i, j) is node i*ncols + j.
G = construct_pixel_graph(*data.shape)

# Negating the image makes the positive (ridge) pixels the lowest heights,
# so segment_graph floods them first.
segments = segment_graph(G, -data.ravel())

In [None]:
# Colour every basin with the next colour of the (cycled) seaborn "deep"
# palette, paint watershed pixels (label 0) white and the nonzero ridge
# pixels of `data` black.
palette = it.cycle(sns.color_palette("deep"))

image = np.zeros([data.shape[0]*data.shape[1], 3])

for label, members in segments.items():
    # next(palette) is only drawn for basins, so basin k always gets the same colour.
    image[list(members)] = (1.0, 1.0, 1.0) if label == 0 else next(palette)

image[data.ravel() != 0] = (0.0, 0.0, 0.0)

fig, ax = plt.subplots(figsize=(16,16))
# No cmap: imshow ignores it for RGB input.
ax.imshow(image.reshape((data.shape[0], data.shape[1], 3)))
ax.axis('off')
plt.show()

In [None]:
def bfs_dilation(G, ones):
    """Hop-count distance from every node to the nearest seed node.

    Runs a multi-source breadth-first search over `G`, starting from all nodes
    in `ones` at once.

    Parameters
    ----------
    G : graph-like
        Must provide ``G.order()`` and ``G[v]`` (an iterable of the neighbours
        of ``v``). Nodes are assumed to be the integers ``0..G.order()-1``,
        because they index the result array.
    ones : iterable of int
        Seed nodes, which get distance 0.

    Returns
    -------
    numpy.ndarray of int, length ``G.order()``
        Distance in edges to the nearest seed, or -1 for nodes no seed can
        reach (every node when `ones` is empty).
    """
    dist = np.full(G.order(), -1, int)
    queue = deque()

    for seed in ones:
        dist[seed] = 0
        queue.append(seed)

    # dist == -1 doubles as the "not yet discovered" marker, so no separate
    # visited set is needed.
    while queue:
        node = queue.popleft()
        next_dist = dist[node] + 1
        for nbr in G[node]:
            if dist[nbr] == -1:
                dist[nbr] = next_dist
                queue.append(nbr)

    return dist

def euclidean_dilation(ones, positions, verbose=False):
    """Euclidean distance from every node to the nearest seed node (brute force).

    Parameters
    ----------
    ones : iterable of int
        Seed node ids.
    positions : dict
        Maps node id -> coordinate vector. Keys are assumed to be
        ``0..len(positions)-1``, and all vectors are assumed to have the same length.
    verbose : bool, optional
        If True, print the number of seeds and a progress counter every
        10 seeds. This was unconditional before; it now defaults to off.

    Returns
    -------
    numpy.ndarray of float, length ``len(positions)``
        Distance to the closest seed, or -1.0 everywhere when `ones` is empty.
    """
    n_nodes = len(positions)
    dist = np.full(n_nodes, -1.0, float)

    if verbose:
        print(len(ones))

    # Stack the coordinates once; row j holds positions[j].
    coords = np.asarray([positions[j] for j in range(n_nodes)], dtype=float)

    for k, i in enumerate(ones):
        if verbose and k % 10 == 0:
            print(k)
        # Distance from seed i to every node in one vectorised step,
        # replacing the original per-pair Python loop.
        d = np.sqrt(np.sum((coords - coords[i]) ** 2, axis=1))
        dist = d if k == 0 else np.minimum(dist, d)

    return dist
            
        
    
            
G, positions = construct_pixel_graph(*data.shape, pos=True)

# Graph (8-connected hop) distance from every pixel to the nearest ridge pixel (data == 1).
# Brute-force Euclidean alternative:
# dist = euclidean_dilation(np.where(data.flatten() == 1)[0], positions)
dist = bfs_dilation(G, np.flatnonzero(data == 1))


In [None]:
fig, ax = plt.subplots(figsize=(16, 16))
# Distance to the nearest ridge pixel; the reversed Blues map draws small distances dark.
ax.imshow(np.reshape(dist, data.shape), cmap=plt.cm.Blues_r)
ax.set_axis_off()
plt.show()

In [None]:
# Negate so the pixels farthest from any ridge have the lowest height; segment_graph floods lowest first.
segments = segment_graph(G, np.negative(dist))

In [None]:
# Colour every basin with the next colour of the (cycled) seaborn "deep"
# palette and paint watershed pixels (label 0) white. Then draw the ridge
# pixels (data == 1) black and blank out (white) every pixel farther than
# MAX_SHOWN_DIST hops from a ridge.
MAX_SHOWN_DIST = 10

palette = it.cycle(sns.color_palette("deep"))

image = np.zeros([data.shape[0]*data.shape[1], 3])

for label, members in segments.items():
    # next(palette) is only drawn for basins, so basin k always gets the same colour.
    image[list(members)] = (1.0, 1.0, 1.0) if label == 0 else next(palette)

image[data.ravel() == 1] = (0.0, 0.0, 0.0)

image[dist > MAX_SHOWN_DIST] = (1.0, 1.0, 1.0)

fig, ax = plt.subplots(figsize=(32,32))
# No cmap: imshow ignores it for RGB input.
ax.imshow(image.reshape((data.shape[0], data.shape[1], 3)))
ax.axis('off')
plt.show()

In [None]:
%autoreload

heights = dist

print("Creating Mesh")
comp = homology.mesh_complex(data.shape[0], data.shape[1], False)


# print(data.shape)

# print(comp.faces)

# heights = np.append(heights, np.min(heights)-1)

print("Creating Filtration")
(simp_filt, dims) = homology.construct_lower_star_filtration(comp, np.argsort(heights))


# print(simp_filt)
# print(dims)

print("Calculating Persistence")
(ipairs, hsort, persist, sim_to_pindex) = homology.compute_persistence_pairs(simp_filt, dims, sorted(heights), comp)

print("Persistence Pairs:", ipairs)
print("Infinite Persistence:", persist)

In [None]:
# Per-dimension birth/death heights (mapped through hsort) and multiplicities
# of the finite pairs, plus the birth heights of the infinite classes.
birth = [[] for _ in range(3)]
death = [[] for _ in range(3)]
mult = [[] for _ in range(3)]
pbirth = [[] for _ in range(3)]
for d in range(3):
    for (i, j), count in ipairs[d].items():
        birth[d].append(hsort[i])
        death[d].append(hsort[j])
        mult[d].append(count)
    pbirth[d].extend(hsort[k] for k in persist[d].keys())

# Reference diagonal spanning the full height range.
diagonal = np.linspace(np.min(heights), np.max(heights), 100)

# One persistence diagram per dimension d = 0, 1, 2.
for d in range(3):
    fig, ax1 = plt.subplots(figsize=(8, 8))

    # Finite pairs, shaded by multiplicity.
    ax1.scatter(birth[d], death[d], c=mult[d], marker='.', cmap='Blues', edgecolors='k', linewidths=1.0, s=100)
    # Infinite classes, drawn as red squares at (birth, birth).
    ax1.scatter(pbirth[d], pbirth[d], marker='s', color='r')
    ax1.plot(diagonal, diagonal, 'k--')

    ax1.set_title(f"$d={d}$")
    ax1.set_xlabel(r"birth")
    ax1.set_ylabel(r"death")

    plt.tight_layout()
    plt.show()

In [None]:
%autoreload

# data = mat['ridges'][1000:2000, 1000:2000].astype(np.int8) - mat['valleys'][1000:2000, 1000:2000].astype(np.int8)

data = mat['ridges'][1500:2000, 1500:2000].astype(np.int8)

pixels, heights = homology.binary_dilation_filtration(data, True)

# pixels, heights = homology.laplace_dilation_filtration(data, True)

In [None]:
fig, ax = plt.subplots(figsize=(16,16))

norm = mcolors.Normalize(vmin=np.min(heights), vmax=np.max(heights))
cmap = mpl.cm.GnBu_r

n_pix = data.shape[0] * data.shape[1]
# NOTE(review): assumes `pixels` is a permutation of vertex ids aligned with
# `heights` (heights[k] belongs to vertex pixels[k]). argsort(pixels) then puts
# each height back at its vertex id, and the first n_pix vertices are the
# image pixels. Confirm against homology.binary_dilation_filtration.
pixel_heights = np.asarray(heights)[np.argsort(pixels)][:n_pix]
im = ax.imshow(pixel_heights.reshape(data.shape), cmap=cmap, norm=norm)
ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:
%autoreload

print("Creating Mesh")
comp = homology.pixel_mesh(data.shape[0], data.shape[1], True)

# print(data.shape)

# print(comp.faces)

print("Creating Filtration")
(simp_filt, dims) = homology.construct_lower_star_filtration(comp, pixels)

# print(simp_filt)
# print(dims)

print("Calculating Persistence")
(ipairs, hsort, persist, sim_to_pindex) = homology.compute_persistence_pairs(simp_filt, dims, heights, comp)

# print("Persistence Pairs:", ipairs)
print("Infinite Persistence:", persist)

In [None]:
# Per-dimension birth/death heights (mapped through hsort) and multiplicities
# of the finite pairs, plus the birth heights of the infinite classes.
birth = [[] for _ in range(3)]
death = [[] for _ in range(3)]
mult = [[] for _ in range(3)]
pbirth = [[] for _ in range(3)]
for d in range(3):
    for (i, j), count in ipairs[d].items():
        birth[d].append(hsort[i])
        death[d].append(hsort[j])
        mult[d].append(count)
    pbirth[d].extend(hsort[k] for k in persist[d].keys())

# Reference diagonal over a fixed 0..125 range.
# NOTE(review): this does not adapt to the data; confirm it covers the
# plotted heights.
diagonal = np.linspace(0, 125, 100)

# One persistence diagram per dimension d = 0, 1, 2.
for d in range(3):
    fig, ax1 = plt.subplots(figsize=(8, 8))

    # Finite pairs, shaded by multiplicity.
    ax1.scatter(birth[d], death[d], c=mult[d], marker='.', cmap='Blues', edgecolors='k', linewidths=1.0, s=100)
    # Infinite classes, drawn as red squares at (birth, birth).
    ax1.scatter(pbirth[d], pbirth[d], marker='s', color='r')
    ax1.plot(diagonal, diagonal, 'k--')

    ax1.set_title(f"$d={d}$")
    ax1.set_xlabel(r"birth")
    ax1.set_ylabel(r"death")

    plt.tight_layout()
    plt.show()