In [None]:
%load_ext autoreload

import sys, os
sys.path.insert(0, '../')
sys.path.insert(0, '../python_src/')


import numpy as np
import numpy.linalg as la
import scipy as sp
import scipy.io as spio

import matplotlib.pyplot as plt; plt.rcdefaults()
import matplotlib as mpl
import seaborn as sns
import matplotlib.colors as mcolors
import matplotlib.cm as cm
import matplotlib.image as mpimg

import itertools as it
import collections as co
import queue
import networkx as nx
import time
import pickle

from skimage import filters as skifilters
from skimage import color as skicolor

from sklearn import linear_model

import homology

# Global plotting style. sns.set() re-applies seaborn's default "notebook"
# context (font_scale=1), so it must run *before* set_context(); in the other
# order the poster context is silently discarded.
sns.set(color_codes=True, palette='deep')
sns.set_context('poster', font_scale=1.25)
sns.set_style('ticks', {'xtick.direction': 'in', 'ytick.direction': 'in', 'axes.linewidth': 2.0})
# Applied after the seaborn calls so none of them can reset it.
mpl.rcParams['mathtext.fontset'] = 'cm'



def total_size(o, handlers=None, verbose=False):
    """Return the approximate memory footprint of an object and all of its contents.

    Automatically finds the contents of the following builtin containers and
    their subclasses: tuple, list, deque, dict, set and frozenset.
    To search other containers, add handlers to iterate over their contents:

        handlers = {SomeContainerClass: iter,
                    OtherContainerClass: OtherContainerClass.get_elements}

    Instance attributes are followed through ``__dict__`` and through every
    ``__slots__`` entry declared anywhere in the object's class hierarchy.

    Parameters
    ----------
    o : object
        Root object to measure.
    handlers : dict, optional
        Maps a type to a callable returning an iterable over the contents of
        instances of that type. User handlers are tried before the builtin
        ones, so they also win for subclasses of the builtin containers.
    verbose : bool
        If True, print the shallow size, type and repr of every visited
        object to stderr.

    Returns
    -------
    int
        Sum of ``sys.getsizeof`` over every distinct object reachable from
        *o*; each object is counted once, by identity.
    """
    def dict_handler(d):
        return it.chain.from_iterable(d.items())

    builtin_handlers = {tuple: iter,
                        list: iter,
                        co.deque: iter,
                        dict: dict_handler,
                        set: iter,
                        frozenset: iter,
                        }
    # User handlers go first: the isinstance scan below stops at the first
    # match, so insertion order decides precedence for subclasses.
    all_handlers = dict(handlers or {})
    for typ, handler in builtin_handlers.items():
        all_handlers.setdefault(typ, handler)

    seen = set()                       # ids of objects already counted
    default_size = sys.getsizeof(0)    # estimate for objects without __sizeof__

    def slot_names(cls):
        """Yield the attribute names of all ``__slots__`` declared in *cls*'s MRO."""
        for klass in cls.__mro__:
            slots = klass.__dict__.get('__slots__', ())
            if isinstance(slots, str):  # `__slots__ = 'x'` declares one slot, not one per char
                slots = (slots,)
            for name in slots:
                if name in ('__dict__', '__weakref__'):
                    continue  # __dict__ is handled separately; a weakref does not own its referent
                if name.startswith('__') and not name.endswith('__'):
                    # Private slot names are stored mangled, like private attributes.
                    owner = klass.__name__.lstrip('_')
                    if owner:
                        name = '_' + owner + name
                yield name

    def sizeof(obj):
        if id(obj) in seen:            # do not double count the same object
            return 0
        seen.add(id(obj))
        size = sys.getsizeof(obj, default_size)

        if verbose:
            print(size, type(obj), repr(obj), file=sys.stderr)

        for typ, handler in all_handlers.items():
            if isinstance(obj, typ):
                size += sum(map(sizeof, handler(obj)))
                break

        # An instance can have both a __dict__ and slots (e.g. an unslotted
        # subclass of a slotted class), so check both rather than one or the other.
        if hasattr(obj, '__dict__'):
            size += sizeof(obj.__dict__)
        for name in slot_names(type(obj)):
            if hasattr(obj, name):
                size += sizeof(getattr(obj, name))

        return size

    return sizeof(o)

In [None]:
# Dataset selection.
# NOTE(review): `label` still reads 'Z16' although Everest.mat is what gets
# loaded; in the visible cells it is only used to name the persistence pickle.
# Confirm this is intended.
label = 'Z16'

# Alternative source: a cropped tile from ../sample_data, keyed by `label`.
# mat = spio.loadmat("../sample_data/{}.mat".format(label))
# data = mat[label][500:2500, 500:2500]

# Active source: the 'Expression1' array of Everest.mat becomes the grid
# (`data`) that every later cell reads.
mat = spio.loadmat("Everest.mat")
data = mat['Expression1']

# Rough memory footprint of everything loadmat returned.
print(total_size(mat))

# dx = 5280 * data.shape[0] / 20
# dy = 5280 * data.shape[1] / 20

# data = mpimg.imread('2D6185C9-1DD8-B71C-076FF6978423E103-large.jpg')[:,:,:3]
# data = mpimg.imread('20170105_100151.jpg')[:,:,:3]
# # data = mpimg.imread('2018-01-12.png')[:,:,:3]

# data = skicolor.rgb2gray(data)
# data = skifilters.laplace(data)

# data = np.zeros([100, 100], float)
# data[50, 50] = -1

# print(np.min(data), np.max(data))


# data = np.array([[10, 0, 10],
#                 [2, 1, 2],                        
#                 [3, 6, 3],
#                 [4, 5, 4]])

# data = np.array([[3,2,2],
#                 [3,1,1],
#                 [0,1,0]])

# data = np.array([[8, 3, 2],
#                 [7, 8, 1],
#                 [6, 9, 1],
#                 [5, 10, 0]])


print(data.shape)

ls = mcolors.LightSource()

# Hillshade illumination values lie in [0, 1], so map them straight onto a grey ramp.
# norm = mpl.colors.Normalize(vmin=np.min(data),vmax=np.max(data))
norm = mpl.colors.Normalize(vmin=0, vmax=1)
smap = cm.ScalarMappable(norm=norm, cmap=plt.cm.Greys_r)

fig, ax = plt.subplots(figsize=(16, 16))
# im = ax.imshow(smap.to_rgba(data))

# One RGBA row per pixel, so missing cells can be recoloured with a flat mask.
image = smap.to_rgba(ls.hillshade(data, vert_exag=100, dx=1, dy=1)).reshape([data.shape[0]*data.shape[1], 4])

# Paint NaN (missing) heights orange. to_rgba yields exact [0, 1] channels;
# int(hex)/256 could never reach 1.0 (0xff -> 0.996).
h = "ff7f00"
color = list(mcolors.to_rgba('#' + h))
image[np.isnan(data).ravel()] = color
im = ax.imshow(image.reshape([data.shape[0], data.shape[1], 4]))


# ax.hlines(np.linspace(int(data.shape[0]/3), 2*int(data.shape[0]/3), 2), 0, data.shape[1])
# ax.vlines(np.linspace(int(data.shape[1]/3), 2*int(data.shape[1]/3), 2), 0, data.shape[0])


ax.axis('off')
plt.show()

# X = 1
# Y = 1

# data = data[Y*int(data.shape[0]/2):(Y+1)*int(data.shape[0]/2), 
#             X*int(data.shape[1]/2):(X+1)*int(data.shape[1]/2)]

# print(data.shape)

# fig, ax = plt.subplots(figsize=(16,16))
# im = ax.imshow(smap.to_rgba(ls.hillshade(data, vert_exag=100, dx=1, dy=1)))

# ax.axis('off')
# plt.show()
# print(total_size(data))


In [None]:
%autoreload

# Forwarded as `dual=` to every homology call below.
dual = True

# Each stage is timed with time.perf_counter(), the monotonic clock meant
# for measuring intervals (time.time() can jump with system clock changes).

# --- Cubical complex over the grid ---------------------------------------
start = time.perf_counter()

print("Constructing complex")
comp = homology.construct_cubical_complex(data.shape, oriented=False, dual=dual)

print("Checking boundary operator")
print(homology.check_boundary_op(comp))

print("Constructing cofacets")
comp.construct_cofacets()

print(total_size(comp))

end = time.perf_counter()
print("Elapsed Time:", end - start)

# --- Vertex filtration (vertex_time holds the flattened data values) -----
start = time.perf_counter()

print("Finding vertex order")

vertex_time = data.flatten()
vertex_order = homology.construct_vertex_filtration_order(comp, vertex_time, euclidean=False, positions=None)

print(total_size(vertex_time))
print(total_size(vertex_order))

end = time.perf_counter()
print("Elapsed Time:", end - start)

# --- Insertion time of every cell ----------------------------------------
start = time.perf_counter()

print("Finding Insertion Times")

insert_order = homology.construct_time_of_insertion_map(comp, vertex_time, vertex_order, dual=dual)

print(total_size(insert_order))

end = time.perf_counter()
print("Elapsed Time:", end - start)

# --- Discrete gradient ---------------------------------------------------
start = time.perf_counter()

print("Constructing discrete gradient")

V = homology.construct_discrete_gradient(comp, insert_order, dual=dual)

print(total_size(V))

end = time.perf_counter()
print("Elapsed Time:", end - start)

# A cell is critical when the gradient maps it to itself (V[v] == v).
n = sum(1 for v in range(len(V)) if V[v] == v)
print("Number Critical Cells:", n)

print("Reversing gradient")
coV = homology.reverse_discrete_gradient(V)

print(total_size(coV))

# --- Morse complex -------------------------------------------------------
start = time.perf_counter()

print("Calculating Morse complex")
mcomp = homology.construct_morse_complex(V, comp, oriented=False)

print("Checking boundary operator")
print(homology.check_boundary_op(mcomp))

print("Constructing cofacets")
mcomp.construct_cofacets()

print(total_size(mcomp))

end = time.perf_counter()
print("Elapsed Time:", end - start)

# --- Basins and Morse skeleton -------------------------------------------
start = time.perf_counter()

print("Finding basins")
basins = homology.find_basins(mcomp, coV, comp, insert_order, dual=dual)

print(total_size(basins))

print("Calculating Morse skeleton")
skeleton = homology.find_morse_skeleton(mcomp, V, comp, 1, insert_order, dual=dual)

print(total_size(skeleton))

end = time.perf_counter()
print("Elapsed Time:", end - start)

# Give every basin a solid colour (palette1) and a colormap name (palette2).
# Both cycle when there are more basins than entries.
palette1 = it.cycle(sns.color_palette("deep"))
palette2 = it.cycle(['Blues_r', 'Greens_r', 'Purples_r', 'Oranges_r', 'RdPu_r'])

basin_color_map1 = {}
basin_color_map2 = {}
for v in basins:
    basin_color_map1[v] = next(palette1)
    basin_color_map2[v] = next(palette2)

print("Complete")

In [None]:
%autoreload

print("Simplifying Morse Complex")

# Threshold handed to simplify_morse_complex. It presumably uses the same
# units as the insert_order times; confirm against homology.
SIMPLIFY_THRESHOLD = 2e-2
V, coV = homology.simplify_morse_complex(SIMPLIFY_THRESHOLD, V, coV, comp, insert_order)

# Count the cells that are still self-paired (critical) after simplification.
# The result was previously computed and then discarded.
n = sum(1 for v in V if v == V[v])
print("Number Critical Cells:", n)

print("Calculating Morse complex")
mcomp = homology.construct_morse_complex(V, comp, oriented=False)

print("Checking Boundary Operator")
# Check the Morse complex just rebuilt, not the underlying cubical complex `comp`.
print(homology.check_boundary_op(mcomp))

mcomp.construct_cofacets()

print("Finding basins")
basins = homology.find_basins(mcomp, coV, comp, insert_order, dual=dual)

print("Calculating Morse skeleton")
skeleton = homology.find_morse_skeleton(mcomp, V, comp, 1, insert_order, dual=dual)

print("Complete")

In [None]:
%autoreload


# image = np.ones([data.shape[0]*data.shape[1], 3])

# for i in basins:
#     image[list(basins[i])] = basin_color_map1[i]
    
# image[list(skeleton)] = (0.0, 0.0, 0.0)

# # for i in basins:
# #     image[i] = (1.0, 1.0, 1.0)
        
# print(np.min(data), np.max(data))

# fig, ax = plt.subplots(figsize=(16, 16))
# ax.imshow(image.reshape((data.shape[0], data.shape[1], 3)))
# ax.axis('off')
# plt.show()



# RGBA image with one row per pixel; pixels outside every basin stay fully transparent.
image = np.zeros([data.shape[0]*data.shape[1], 4])

ls = mcolors.LightSource()
shaded = ls.hillshade(data, vert_exag=100, dx=1, dy=1).flatten()

# Hillshade values are in [0, 1]; one normaliser serves every basin.
norm = mpl.colors.Normalize(vmin=0, vmax=1)

for i in basins:
    # Alternative: colour each basin by height instead of by hillshade.
    #     norm = mpl.colors.Normalize(vmin=np.min(data), vmax=np.max(data))
    #     smap = cm.ScalarMappable(norm=norm, cmap=basin_color_map2[i])
    #     image[list(basins[i])] = smap.to_rgba(vertex_time[list(basins[i])])
    members = list(basins[i])
    smap = cm.ScalarMappable(norm=norm, cmap=basin_color_map2[i])
    image[members] = smap.to_rgba(shaded[members])

# Morse skeleton pixels in opaque black.
image[list(skeleton)] = (0.0, 0.0, 0.0, 1.0)

# Each basin key is also used as a pixel index and drawn white.
# Presumably that is the basin's critical cell; confirm.
for i in basins:
    image[i] = (1.0, 1.0, 1.0, 1.0)

# The grid may contain NaN (the data cell paints those cells separately), and
# np.min/np.max would then print nan.
print(np.nanmin(data), np.nanmax(data))

fig, ax = plt.subplots(figsize=(16, 16))
ax.imshow(image.reshape((data.shape[0], data.shape[1], 4)))
ax.axis('off')
plt.show()

In [None]:
%autoreload

print("Calculating persistence pairs...")

# Filtration of the Morse complex, built from the per-cell insertion times.
filtration = homology.construct_filtration(mcomp, insert_order)

# Optional weights for the optimal-cycle mode below:
# weights = homology.get_morse_weights(mcomp, V, coV, comp.facets, comp.cofacets)
# print("Weights:", weights)

start = time.time()

# Other compute_persistence modes, kept for reference:
# (pairs, bcycles) = homology.compute_persistence(mcomp, filtration,
#                                                 birth_cycles=True, optimal_cycles=False)
# (pairs, bcycles, ocycles) = homology.compute_persistence(mcomp, filtration,
#                                                          birth_cycles=True, optimal_cycles=True,
#                                                          weights=weights, relative_cycles=True)
pairs = homology.compute_persistence(mcomp, filtration)
print("Number Pairs:", len(pairs))

end = time.time()
print("Elapsed Time:", end - start)

# print("Pairs:", pairs)
# print("Birth Cycles:", bcycles)
# print("Death Cycles:", ocycles)

print("Complete")

In [None]:
# Index of the filtration time inside each insert_order entry.
TIME = 0

# --- Persistence diagram -------------------------------------------------
# birth[d] / death[d] hold the times of pairs whose birth cell has dimension d.
birth = [[] for _ in range(mcomp.dim + 1)]
death = [[] for _ in range(mcomp.dim + 1)]

for (i, j) in pairs:
    d = mcomp.get_dim(i)
    birth[d].append(insert_order[i][TIME])
    # A pair without a death cell never dies; it is drawn on the diagonal.
    death[d].append(insert_order[j if j is not None else i][TIME])

fig = plt.figure(figsize=(8, 8))
ax1 = fig.add_subplot(1, 1, 1)

ax1.scatter(birth[0], death[0], marker='.', color='b', label="$0$-cycles")
# NOTE(review): 1-cycles are plotted as (death, birth), i.e. mirrored across the
# diagonal relative to the axis labels. Kept unchanged; confirm it is intended.
ax1.scatter(death[1], birth[1], marker='.', color='r', label="$1$-cycles")

# Reference diagonal over the range of vertex times. The grid may contain NaN;
# np.min/np.max would then return nan and the line would not be drawn.
diagonal = np.linspace(np.nanmin(vertex_time), np.nanmax(vertex_time), 100)
ax1.plot(diagonal, diagonal, 'k--')

ax1.set_title(r"$d={}$".format(0))
ax1.set_xlabel(r"Birth [height]")
ax1.set_ylabel(r"Death [height]")
ax1.legend(fontsize='large')

plt.tight_layout()
plt.show()

# --- Persistence against birth time --------------------------------------
TIME = 0

birth = [[] for _ in range(mcomp.dim + 1)]
persistence = [[] for _ in range(mcomp.dim + 1)]

for (i, j) in pairs:
    # Skip pairs that never die and zero-persistence pairs; zero cannot be
    # shown on the logarithmic y axis.
    if j is None or insert_order[i][TIME] == insert_order[j][TIME]:
        continue
    d = mcomp.get_dim(i)
    birth[d].append(insert_order[i][TIME])
    persistence[d].append(insert_order[j][TIME] - insert_order[i][TIME])

fig = plt.figure(figsize=(8, 8))
ax1 = fig.add_subplot(1, 1, 1)
ax1.set_yscale('log')

ax1.scatter(birth[0], persistence[0], marker='.', color='b', label="$0$-cycles")
ax1.scatter(birth[1], persistence[1], marker='.', color='r', label="$1$-cycles")

ax1.set_xlabel(r"Birth [height]")
ax1.set_ylabel(r"Persistence [height]")
# ax1.set_ylim(1e-8, 1e0)
# ax1.hlines(10**(-1.9), -0.4, 1.0)

ax1.legend(fontsize='large')

plt.tight_layout()
plt.show()


In [None]:
%autoreload

# Per-pair persistence (death time - birth time) and feature area in pixels.
# Both lists are aligned index-for-index with `pairs`.
persistence = []
area = []

for pi, (i, j) in enumerate(pairs):
    if not pi % 10000:
        print(pi, "/", len(pairs))   # progress report

    if j is not None:
        persistence.append(insert_order[j][0] - insert_order[i][0])
        feature = homology.extract_persistence_feature(i, j, mcomp, comp, V, coV, insert_order)
        pixels = homology.convert_to_pixels(feature, comp, insert_order, dual=dual)
        area.append(len(pixels))
        continue

    # Unpaired class: never dies and is credited with the whole image.
    persistence.append(np.inf)
    area.append(data.shape[0]*data.shape[1])

print("Complete")

In [None]:
# P = []
# A = []

# for pi, (i, j) in enumerate(pairs):
    
#     if j is not None and comp.dims[i] == 0 and persistence[pi] > 0:
#         P.append(persistence[pi])
#         A.append(area[pi])
    

# g = sns.JointGrid(np.log10(P), np.log10(A))
# g.plot_joint(plt.scatter, marker='.', s=8, color='b', label="Valleys")

# g.plot_marginals(sns.distplot, kde=False)

# ax = g.ax_joint
# ax.set_xlabel(r"$\log_{10}$Persistence [height]")
# ax.set_ylabel(r'$\log_{10}$Area [pixels$^2$]')
# ax.legend(fontsize='large')

# plt.show()


# P = []
# A = []

# for pi, (i, j) in enumerate(pairs):
    
#     if j is not None and comp.dims[i] == 1 and persistence[pi] > 0:
#         P.append(persistence[pi])
#         A.append(area[pi])
    

# g = sns.JointGrid(np.log10(P), np.log10(A))
# g.plot_joint(plt.scatter, marker='.', s=8, color='b', label="Mountains")

# g.plot_marginals(sns.distplot, kde=False)

# ax = g.ax_joint
# ax.set_xlabel(r"$\log_{10}$Persistence [height]")
# ax.set_ylabel(r'$\log_{10}$Area [pixels$^2$]')
# ax.legend(fontsize='large')

# plt.show()


# All finite pairs with positive persistence (mountains and valleys together).
keep = [pi for pi, (_, j) in enumerate(pairs) if j is not None and persistence[pi] > 0]
P = [persistence[pi] for pi in keep]
A = [area[pi] for pi in keep]

# Keyword data arguments: recent seaborn no longer accepts them positionally.
g = sns.JointGrid(x=np.log10(P), y=np.log10(A))
g.plot_joint(plt.scatter, marker='.', s=8, color='b', label="Mountains + Valleys")

g.plot_marginals(sns.distplot, kde=False)

# Subset used for the power-law fit: large, clearly persistent features.
fit_keep = [pi for pi, (_, j) in enumerate(pairs)
            if j is not None and persistence[pi] > 1e-2 and area[pi] > 8e2]
P = [persistence[pi] for pi in fit_keep]
A = [area[pi] for pi in fit_keep]

log_P = np.log10(P).reshape(-1, 1)
log_A = np.log10(A).reshape(-1, 1)

# Linear fit in log-log space: log10 A = c + m log10 P, i.e. A = 10^c P^m.
regr = linear_model.LinearRegression()
regr.fit(log_P, log_A)

ax = g.ax_joint

ax.scatter(log_P.ravel(), log_A.ravel(), marker='.', s=8, color='r')

x_fit = np.linspace(log_P.min(), log_P.max(), 100)
# The intercept is log10 of the prefactor, so show it as a power of ten.
ax.plot(x_fit, regr.predict(x_fit.reshape(-1, 1)).ravel(), 'k--',
        label=r'$A= 10^{{{:.1f}}} P^{{{:.1f}}} $'.format(regr.intercept_[0], regr.coef_[0, 0]))

ax.set_xlabel(r"$\log_{10}$Persistence [height]")
ax.set_ylabel(r'$\log_{10}$Area [pixels$^2$]')
ax.legend(fontsize='large')


plt.show()

In [None]:
%autoreload

def plot_features(N_features, sorted_pairs):
    """Draw the first ``N_features`` persistence features over a hillshade of ``data``.

    Parameters
    ----------
    N_features : int
        Number of leading entries of ``sorted_pairs`` to draw. It is clamped
        to ``len(sorted_pairs)``; nothing is drawn when it is not positive.
    sorted_pairs : sequence of (persistence, (i, j))
        Persistence pairs, where ``i`` is the birth cell and ``j`` the death
        cell in ``mcomp``. The caller sorts them by decreasing persistence.

    Dimension-0 features are filled using a cycling colormap. Dimension-1
    features add their boundary pixels to a black skeleton. The last feature
    drawn is recoloured with 'Oranges_r'; if it is 1-dimensional, its boundary
    is drawn in pink.

    Reads the notebook globals ``data``, ``mcomp``, ``comp``, ``V``, ``coV``,
    ``insert_order`` and ``dual``.
    """
    N_features = min(N_features, len(sorted_pairs))
    if N_features <= 0:
        return

    palette = it.cycle(['Blues_r', 'Greens_r', 'Purples_r', 'RdPu_r'])

    # Greyscale hillshade background; hillshade values are in [0, 1].
    ls = mcolors.LightSource()
    shaded = ls.hillshade(data, vert_exag=100, dx=1, dy=1).flatten()

    norm = mpl.colors.Normalize(vmin=0, vmax=1)
    smap = cm.ScalarMappable(norm=norm, cmap=plt.cm.Greys_r)

    # One RGBA row per pixel so features can be painted by flat pixel index.
    image = smap.to_rgba(shaded).reshape((data.shape[0]*data.shape[1], 4))

    cycle_skeleton = set()   # boundary pixels of every 1-dimensional feature drawn
    last_feature = []        # pixels of the final feature drawn
    last_cycle = []          # boundary pixels of the most recent 1-dimensional feature

    for _, (i, j) in sorted_pairs[:N_features]:
        feature = homology.extract_persistence_feature(i, j, mcomp, comp, V, coV, insert_order)
        pixels = list(homology.convert_to_pixels(feature, comp, insert_order, dual=dual))

        last_feature = pixels

        dim = mcomp.get_dim(i)
        if dim == 0:
            basin_map = cm.ScalarMappable(norm=norm, cmap=next(palette))
            image[pixels] = basin_map.to_rgba(shaded[pixels])
        elif dim == 1:
            bound = homology.get_boundary(feature, comp)
            last_cycle = homology.convert_to_pixels(bound, comp, insert_order, dual=dual)
            cycle_skeleton.update(last_cycle)

    last_pair = sorted_pairs[N_features - 1]
    last_birth_cell = last_pair[1][0]
    print(last_pair, mcomp.get_dim(last_birth_cell))

    # Highlight the final feature in orange.
    highlight = cm.ScalarMappable(norm=norm, cmap='Oranges_r')
    image[last_feature] = highlight.to_rgba(shaded[last_feature])

    image[list(cycle_skeleton)] = (0.0, 0.0, 0.0, 1.0)

    if mcomp.get_dim(last_birth_cell) == 1:
        # Outline the final cycle in pink. to_rgba gives exact [0, 1] channels,
        # whereas int(hex)/256 could never reach 1.0.
        image[list(last_cycle)] = mcolors.to_rgba('#e7298a')

    fig, ax = plt.subplots(figsize=(16, 16))

    ax.imshow(image.reshape((data.shape[0], data.shape[1], 4)))

    ax.axis('off')
    plt.show()
        
# Finite pairs ranked by decreasing persistence, as (persistence, (i, j)).
# This uses its own name rather than rebinding `persistence`: that list is
# index-aligned with `pairs` and is indexed by pair position elsewhere
# (area and clustering cells).
ranked_pairs = sorted(
    ((insert_order[j][0] - insert_order[i][0], (i, j)) for (i, j) in pairs if j is not None),
    reverse=True,
)

print(len(ranked_pairs))

# Show the top one, two and three features in turn.
for n in range(min(3, len(ranked_pairs))):
    print(n)
    plot_features(n + 1, ranked_pairs)
            

In [None]:
# Index of the filtration time inside each insert_order entry.
TIME = 0

# Tile indices for the output file name. The tiling code in the data cell,
# which would define X and Y, is commented out, so the whole grid is tile
# (0, 0). Update these if tiling is re-enabled. Reading X / Y here raised
# NameError, or picked up unrelated arrays from the clustering cell.
TILE_X, TILE_Y = 0, 0

# Persistence of every pair; np.inf for pairs without a death cell.
# Stored under its own name: rebinding `data` would replace the height grid
# that the plotting cells read.
persist_map = {}
for (i, j) in pairs:
    if j is not None:
        persist_map[(i, j)] = insert_order[j][TIME] - insert_order[i][TIME]
    else:
        persist_map[(i, j)] = np.inf

with open("{}_X{}Y{}_persist.pkl".format(label, TILE_X, TILE_Y), 'wb') as f:
    pickle.dump(persist_map, f)

In [None]:
# Cumulative histograms of log10(persistence) for several datasets.
# Each entry is (dataset label, tiles per side, line colour, line width).
# Files are named "<label>_X<i>Y<j>_persist.pkl", as written by the pickling
# cell. They are unpickled, so only point this at files this notebook produced.
PERSIST_SOURCES = [
    ("Z16", 3, 'b', 1),
    ("Z8", 1, 'r', 1),
    ("Z4", 1, 'g', 1),
    ("Z1", 1, 'm', 2),
]

fig, ax = plt.subplots(1, 1, figsize=(8, 8))

for source_label, n_tiles, line_color, line_width in PERSIST_SOURCES:
    persist_data = {}
    for i in range(n_tiles):
        for j in range(n_tiles):
            with open("{}_X{}Y{}_persist.pkl".format(source_label, i, j), 'rb') as f:
                persist_data[(i, j)] = pickle.load(f)

    for (i, j) in persist_data:
        # Keep only finite, positive persistences for every dataset. np.inf
        # marks pairs that never die, and log10(0) = -inf would make the
        # histogram range non-finite.
        persist = [p for p in persist_data[(i, j)].values() if 0 < p < np.inf]

        sns.distplot(np.log10(persist), ax=ax, kde=False, norm_hist=True,
                     hist_kws={"cumulative": True, "histtype": "step", "linewidth": line_width,
                               "alpha": 1.0, "color": line_color})

ax.set_xlabel(r"$\log_{10}$Persistence")

plt.show()

In [None]:
from sklearn import mixture, cluster, manifold

# Persistence recomputed index-for-index with `pairs`, and hence with `area`.
# The global `persistence` is rebound by other cells (to per-dimension lists,
# or to a sorted list of (persistence, pair) tuples), so indexing it by pair
# position here is unreliable.
pair_persistence = [np.inf if j is None else insert_order[j][0] - insert_order[i][0]
                    for (i, j) in pairs]

# Finite pairs with positive persistence and area; both get log-transformed.
keep = [pi for pi, (_, j) in enumerate(pairs)
        if j is not None and pair_persistence[pi] > 0 and area[pi] > 0]
P = [pair_persistence[pi] for pi in keep]
A = [area[pi] for pi in keep]

# One row per pair: (log10 persistence, log10 area). Named `features` /
# `labels` so that global X / Y used elsewhere are not overwritten.
features = np.log10(np.vstack((P, A)).T)

n_clusters = 2


def _scatter_by_label(points, labels, n_labels):
    """Scatter 2-D ``points``, coloured by integer cluster ``labels`` in ``range(n_labels)``."""
    fig, ax = plt.subplots(1, 1, figsize=(6, 6))
    palette = it.cycle(sns.color_palette("deep"))
    for k in range(n_labels):
        ax.scatter(points[labels == k, 0], points[labels == k, 1], marker='.', s=8, color=next(palette))
    plt.show()


print("Gaussian Mixture Model")

model = mixture.GaussianMixture(n_components=n_clusters, covariance_type='full')

model.fit(features)
print(model.bic(features))

labels = model.predict(features)
_scatter_by_label(features, labels, n_clusters)

# Covariance ellipses (disabled):
# for mean, cov in zip(model.means_, model.covariances_):
#     v, w = la.eigh(cov)
#     angle = np.arctan2(w[0][1], w[0][0])
#     angle = 180. * angle / np.pi  # convert to degrees
#     v = 2. * np.sqrt(2.) * np.sqrt(v)
#     ell = mpl.patches.Ellipse(mean, v[0], v[1], 180. + angle, color='k')
#     ell.set_clip_box(ax.bbox)
#     ell.set_alpha(.5)
#     ax.add_artist(ell)

print("K-Mean Clustering")

model = cluster.KMeans(n_clusters=n_clusters)

model.fit(features)

labels = model.predict(features)
_scatter_by_label(features, labels, n_clusters)



# print("Mean-Shift")



# model = cluster.MeanShift()

# model.fit(X)
      
# Y = model.predict(X)

# fig, ax = plt.subplots(1,1, figsize=(6,6))

# palette = it.cycle(sns.color_palette("deep"))

# for i , center in enumerate(model.cluster_centers_):
    
#     color = next(palette)
    
#     ax.scatter(X[Y == i, 0], X[Y == i, 1], marker='.', s=8, color=color)


# plt.show()

