In [None]:
import numpy as np
from skimage import io
from utils import grainPreprocess, grainMark
from numpy.lib.stride_tricks import sliding_window_view

from matplotlib import pyplot as plt

from skimage import io, color, filters, morphology, util
from skimage.measure import EllipseModel
from skimage.color import rgb2gray
from skimage import filters, util
from skimage.morphology import disk, skeletonize, ball
from skimage.measure import approximate_polygon
from skimage import transform
import copy
from PIL import Image, ImageDraw, ImageFilter, ImageOps

from matplotlib import cm
import networkx as nx
from tqdm.notebook import tqdm

from bresenham import bresenham
import pandas as pd

In [None]:
def _binarize_and_gradient(image):
    """Binarize an image with Otsu's threshold and compute the border map of the result.

    Shared first stage of ``preprocess_image_1`` and ``preprocess_image_2``.

    Parameters
    ----------
    image : ndarray
        2-D grayscale or 3-D colour image; 3-D input is converted with ``color.rgb2gray``.
        NOTE(review): rank filters expect uint8/uint16 input; a float image (e.g. the
        output of rgb2gray) is converted by skimage, possibly with a precision-loss warning.

    Returns
    -------
    binary : ndarray of uint8
        255 where the median-filtered image is above its Otsu threshold, 0 elsewhere.
    grad : ndarray of uint8
        Rank gradient (local max - local min over a radius-1 disk) of ``binary``:
        255 on pixels whose neighbourhood contains both classes, 0 elsewhere.
    """
    if image.ndim == 3:
        image = color.rgb2gray(image)

    # Median over a radius-3 disk suppresses speckle before thresholding.
    image = filters.rank.median(image, morphology.disk(3))

    binary = np.where(image > filters.threshold_otsu(image), 255, 0).astype(np.uint8)
    grad = filters.rank.gradient(binary, morphology.disk(1))
    return binary, grad


def preprocess_image_1(image):
    """Encode regions and region borders as grey levels for contour extraction.

    Each pixel is ``(1 - binary + grad) * 127`` evaluated in uint8 arithmetic, which
    wraps modulo 256. For the 0/255 ``binary``/``grad`` maps this yields:

    - dark interior (binary 0, grad 0): 127
    - dark side of a border (binary 0, grad 255): 0
    - bright interior (binary 255, grad 0): 254
    - bright side of a border (binary 255, grad 255): 127

    Parameters
    ----------
    image : ndarray
        2-D grayscale or 3-D colour image.

    Returns
    -------
    ndarray of uint8
        The encoded image, same height/width as the input.
    """
    binary, grad = _binarize_and_gradient(image)
    # NOTE(review): the wraparound is kept on purpose. The former np.clip(..., 0, 255)
    # on this uint8 result was a no-op, and later cells (white padding, the 160/222
    # mean-intensity thresholds) appear tuned to the wrapped values. If a saturating
    # 0/127/255 encoding was intended, compute in a signed dtype and re-tune those cells.
    bin_grad = (1 - binary + grad) * 127
    return bin_grad.astype(np.uint8)


def preprocess_image_2(image):
    """Return the border map of the Otsu-binarized image.

    Parameters
    ----------
    image : ndarray
        2-D grayscale or 3-D colour image.

    Returns
    -------
    ndarray of uint8
        255 on pixels adjacent to a bright/dark boundary, 0 elsewhere (a new array).
    """
    _, grad = _binarize_and_gradient(image)
    return grad.astype(np.uint8)

def draw_edges(image, cnts, color=(0, 139, 139), r=4, e_width=5, l_width=4, line_color=(255, 140, 0)):
    """Draw every contour as a polyline with a dot on each vertex.

    Parameters
    ----------
    image : PIL.Image.Image
        Image to annotate. It is copied first; the caller's image is left unchanged.
    cnts : iterable of point sequences
        Contours. For each point ``p``, ``p[0]`` is the PIL x (horizontal) coordinate
        and ``p[1]`` the PIL y coordinate. Contours with fewer than two points are skipped.
    color : tuple
        Fill colour of the vertex dots.
    r : int
        Radius of the vertex dots in pixels.
    e_width : int
        Outline width passed to ``ImageDraw.ellipse``.
    l_width : int
        Width of the polyline segments.
    line_color : tuple
        Colour of the polyline segments (defaults to the previously hard-coded orange).

    Returns
    -------
    PIL.Image.Image
        The annotated copy.
    """
    img = copy.copy(image)
    draw = ImageDraw.Draw(img)

    for cnt in cnts:
        if len(cnt) < 2:
            continue
        prev_x, prev_y = cnt[0][0], cnt[0][1]
        for point in cnt:
            cur_x, cur_y = point[0], point[1]
            draw.ellipse((cur_x - r, cur_y - r, cur_x + r, cur_y + r), fill=color, width=e_width)
            # On the first point this is a zero-length segment, as in the original drawing order.
            draw.line((prev_x, prev_y, cur_x, cur_y), fill=line_color, width=l_width)
            prev_x, prev_y = cur_x, cur_y

    return img

# Extract contour nodes, mark entry/exit points, and plot them

In [None]:
# Load a 200x200 crop. `orig_img` keeps the raw pixels and the preprocessed image goes
# into `bin_grad_img`. (Previously `orig_img` was overwritten, so the later
# `preprocess_image_2(orig_img)` re-preprocessed an already preprocessed image.)
orig_img = io.imread('../datasets/original/o_bc_left/Ultra_Co6_2/Ultra_Co6_2-001.jpeg')[:200, :200]

r = 2        # radius (px) of the markers drawn on the plot
eps = 15     # a node closer than eps px to an image edge is an entry/exit candidate
border = 10  # width (px) of the white padding added around the preprocessed image
tol = 3      # tolerance forwarded to grainMark.get_contours

bin_grad_img = preprocess_image_1(orig_img)
img_viz = ImageOps.expand(Image.fromarray(bin_grad_img), border=border, fill='white')
cnts = grainMark.get_contours(np.array(img_viz), tol=tol)

img_shape = np.array(img_viz.size)  # PIL size: (width, height)

# Number every contour point (node) and keep both directions of the lookup.
# NOTE(review): a coordinate that occurs more than once keeps only its last index
# in image_nodes_coord2nodes_index.
image_nodes_coord2nodes_index = {}
nodes_index2global_nodes_coord = {}
num_of_nodes = 0
for points in cnts:
    for point in points:
        x, y = point[0], point[1]
        image_nodes_coord2nodes_index[(x, y)] = num_of_nodes
        nodes_index2global_nodes_coord[num_of_nodes] = (x, y)
        num_of_nodes += 1

# Entry points: nodes with y < eps. y_entry_max is the largest entry y; exit points
# must have a strictly larger y.
x_entry, y_entry = [], []
entry_dict = {}
y_entry_max = 0
for points in cnts:
    for point in points:
        x, y = point[0], point[1]
        if y < eps:
            x_entry.append(x)
            y_entry.append(y)
            y_entry_max = max(y_entry_max, y)
            entry_dict[image_nodes_coord2nodes_index[(x, y)]] = 1

# Exit points: nodes within eps of the left, right or large-y edge, beyond every entry point.
x_exit, y_exit = [], []
exit_dict = {}
for points in cnts:
    for point in points:
        x, y = point[0], point[1]
        near_edge = x < eps or img_shape[0] - x < eps or img_shape[1] - y < eps
        if near_edge and y > y_entry_max:
            x_exit.append(x)
            y_exit.append(y)
            exit_dict[image_nodes_coord2nodes_index[(x, y)]] = 1

# Overlay: contours via draw_edges (which copies its input), entries blue, exits red.
img_drawings = draw_edges(img_viz.convert('RGB'), cnts=cnts, r=r, l_width=1)
draw = ImageDraw.Draw(img_drawings)
for x, y in zip(x_entry, y_entry):
    draw.ellipse((x - r, y - r, x + r, y + r), fill=(0, 0, 200), width=1)
for x, y in zip(x_exit, y_exit):
    draw.ellipse((x - r, y - r, x + r, y + r), fill=(150, 0, 0), width=1)

plt.figure(figsize=(7, 7))
plt.imshow(img_drawings, cmap='gray', origin='lower')
plt.xlabel('x', fontsize=15)
plt.ylabel('y', fontsize=15)
plt.title('Contour nodes: entry (blue) and exit (red) points')
plt.show()

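Node lookup tables:
1) `image_nodes_coord2nodes_index` — dict: contour point `(x, y)` → node index
2) `nodes_index2global_nodes_coord` — dict: node index → `(x, y)`
3) `image_node_coord2node_index` — array indexed `[x, y]` holding the node index at that pixel

Grid lookup tables:
1) `grid_cell_coord2grid_cell_index` — dict: cell coordinate `(xi, yi)` → cell index
2) `grid_cell_index2grid_cell_coord` — dict: cell index → `(xi, yi)`
3) `image_coord2grid_cell_index` — array indexed `[x, y]` holding the index of the cell that contains the pixel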

In [None]:
# Register every contour node in a fresh directed graph (node attribute `pos` = (x, y))
# and in a dense map indexed [x, y] that stores each node's index.
# NOTE(review): the map is zero-filled, so node 0 cannot be told apart from "no node".
g = nx.DiGraph()
image_node_coord2node_index = np.zeros(img_shape)
for key, (x, y) in nodes_index2global_nodes_coord.items():
    image_node_coord2node_index[x, y] = key
    g.add_node(key, pos=(x, y))

# To preview the nodes:
# pos = nx.get_node_attributes(g, 'pos')
# plt.figure(figsize=(5, 5))
# nx.draw(g, pos, with_labels=True, node_color='lightblue', node_size=500, font_size=10)

In [None]:
# Partition the image into non-overlapping cell_size x cell_size grid cells.
cell_size = 10
grid_size = np.int32(img_shape / cell_size)  # number of whole cells along x and y

# Cell (xi, yi) gets index xi * grid_size[1] + yi (yi varies fastest).
grid_cell_coord2grid_cell_index = {}
grid_cell_index2grid_cell_coord = {}
number_of_cells = 0
for xi in range(grid_size[0]):
    for yi in range(grid_size[1]):
        grid_cell_coord2grid_cell_index[(xi, yi)] = number_of_cells
        grid_cell_index2grid_cell_coord[number_of_cells] = (xi, yi)
        number_of_cells += 1

# Pixel [x, y] -> index of the cell containing it. Pixels in the partial strip that
# remains when img_shape is not a multiple of cell_size belong to no cell and are
# marked -1 (previously they kept the fill value 0 and silently pointed at cell 0).
covered_x, covered_y = grid_size * cell_size
image_coord2grid_cell_index = np.full(img_shape, -1)
cell_xi = np.arange(covered_x)[:, None] // cell_size
cell_yi = np.arange(covered_y)[None, :] // cell_size
image_coord2grid_cell_index[:covered_x, :covered_y] = cell_xi * grid_size[1] + cell_yi

# View the node map as a (grid_x, grid_y, cell_size, cell_size) array of cells and sum
# the node indices in each cell (a non-zero sum means the cell holds a node other than 0).
grid = np.array(sliding_window_view(image_node_coord2node_index, (cell_size, cell_size))[::cell_size, ::cell_size])
assert grid.shape[:2] == tuple(grid_size), "grid must have one entry per whole cell"
grid_summed = grid.sum(axis=(2, 3))

plt.imshow(grid_summed, cmap='gray', origin='lower')
plt.title('Sum of node indices per grid cell')
plt.show()

# Static-window (eps) edge search

In [None]:
# Static search parameters.
eps = 100        # search window: eps px to either side in x and eps px in +y
line_eps = 2     # each candidate edge is tested as 2*line_eps+1 lines shifted in x
border_eps = 7   # keep an edge if the shifted lines cross < border_eps border pixels on average
NO_NODE = -1     # fill value of the node map: "no node at this pixel"

# Fresh graph containing every contour node.
g = nx.DiGraph()
# Node map indexed [x, y], using -1 for empty pixels so node 0 is found like any other
# node (with a 0 fill it looked like an empty pixel). A separate name keeps the
# zero-filled `image_node_coord2node_index` used by the grid cells intact.
static_node_map = np.full(img_shape, NO_NODE, dtype=np.int32)
for key in range(num_of_nodes):
    x, y = nodes_index2global_nodes_coord[key]
    static_node_map[x, y] = key
    g.add_node(key, pos=(x, y))

# Border map (255 on region borders, 0 elsewhere), padded exactly like img_viz so node
# coordinates index it directly. It is a NumPy image: index it as [y, x].
grad_map = np.array(ImageOps.expand(Image.fromarray(preprocess_image_2(orig_img)), border=border, fill='black'))


def _window_nodes(node_map, x, y, half_width, depth, pad=2):
    """Return the node indices in node_map[x-half_width-pad : x+half_width+pad, y : y+depth].

    Bounds are clamped to the map. The old inline code set the left bound to
    ``eps - x`` when ``x - eps < 0`` (a bound to the right of x), and could produce a
    negative start that wrapped around when ``x - eps`` was 0 or 1.
    ``pad=2`` reproduces the old extra x margin (presumably to cover the shifted lines).
    """
    x_lo = max(x - half_width - pad, 0)
    x_hi = min(x + half_width + pad, node_map.shape[0])
    y_hi = min(y + depth, node_map.shape[1])
    window = node_map[x_lo:x_hi, y:y_hi]
    return window[window != NO_NODE]


def _border_pixels_on_line(x0, y0, x1, y1, border_map):
    """Count border (== 255) pixels on the Bresenham line from (x0, y0) to (x1, y1).

    Points are (x, y) = (column, row). They are clipped to the map so shifted lines near
    the edge neither raise IndexError nor wrap around through negative indices.
    """
    line = np.array(list(bresenham(x0, y0, x1, y1)))
    cols = np.clip(line[:, 0], 0, border_map.shape[1] - 1)
    rows = np.clip(line[:, 1], 0, border_map.shape[0] - 1)
    return int(np.count_nonzero(border_map[rows, cols] == 255))


m = []  # border-pixel count of every tested line (histogram further below)
for start_node_index in tqdm(range(num_of_nodes)):
    start_node_x, start_node_y = nodes_index2global_nodes_coord[start_node_index]
    for node_index in _window_nodes(static_node_map, start_node_x, start_node_y, eps, eps):
        node_index = int(node_index)
        if node_index == start_node_index:
            continue  # the window includes the start node's own y: avoid self-loops
        end_node_x, end_node_y = nodes_index2global_nodes_coord[node_index]
        # The map is sampled along the line as [y, x]; the old code used [xs, xs],
        # i.e. it sampled the image diagonal instead of the line.
        border_counts = [
            _border_pixels_on_line(start_node_x + dx, start_node_y, end_node_x + dx, end_node_y, grad_map)
            for dx in range(-line_eps, line_eps + 1)
        ]
        m.extend(border_counts)
        if np.mean(border_counts) < border_eps:
            g.add_edge(start_node_index, node_index)

pos = nx.get_node_attributes(g, 'pos')
plt.figure(figsize=(10, 10))
nx.draw(g, pos, with_labels=True, node_color='lightblue', node_size=500, font_size=10)

In [None]:
# Re-display the annotated contour image (entries blue, exits red).
fig, ax = plt.subplots(figsize=(7, 7))
ax.imshow(img_drawings, cmap='gray', origin='lower')
ax.set_xlabel('x', fontsize=15)
ax.set_ylabel('y', fontsize=15)
plt.show()

In [None]:
%%time

# Enumerate every simple path between two node ids of the search graph `g`.
# NOTE(review): 59 and 7 are hard-coded; node ids depend on contour extraction, so
# re-check them after changing the image or parameters. The number of simple paths
# can grow exponentially with graph size, so this can be slow and memory-hungry.
list(nx.all_simple_paths(g, source=59, target=7))

In [None]:
# Distribution of border-pixel counts per tested line, collected in `m` by the static search.
fig, ax = plt.subplots()
ax.hist(m, bins=100)
ax.set_xlim(0, 20)
ax.set_xlabel('border (255) pixels crossed by a line')
ax.set_ylabel('number of lines')
ax.set_title('Static search: border pixels per candidate line')
plt.show()

# Dynamic grid search (work in progress)

In [None]:
def check_borders(point, shape):
    """Return True if ``point`` is a valid 2-D index into an array of shape ``shape``.

    Both coordinates must satisfy ``0 <= point[i] < shape[i]``. The previous
    ``<= shape[i]`` test accepted the one-past-the-end index, which raises IndexError
    when used to index the array.
    """
    return bool(0 <= point[0] < shape[0] and 0 <= point[1] < shape[1])
K = 10  # window in grid cells: K//2 (+2) cells to each side in x, K cells in +y
img_viz_numpy = np.array(img_viz)  # padded preprocess_image_1 output; index as [y, x]

#127, 255

# Start from a fresh graph so the edges found here are not mixed into the
# static-search graph (the old code kept adding edges to that same `g`).
g = nx.DiGraph()
for key, (x, y) in nodes_index2global_nodes_coord.items():
    g.add_node(key, pos=(x, y))


def _grid_window_nodes(cell_x, cell_y, half_width, depth, pad=2):
    """Return the node indices stored in the grid cells of the window
    [cell_x-half_width-pad, cell_x+half_width+pad) x [cell_y, cell_y+depth), clamped to the grid.

    Fixes of the old inline code:
    - the left bound is clamped to 0 (it used to become ``K//2 - cell_x``);
    - the right bound is computed from cell_x (it used cell_y);
    - the slice-relative indices from np.nonzero are shifted back to full-grid indices.

    NOTE(review): `grid` stores 0 for "no node", so node 0 is never returned.
    """
    x_lo = max(cell_x - half_width - pad, 0)
    x_hi = min(cell_x + half_width + pad, grid.shape[0])
    y_hi = min(cell_y + depth, grid.shape[1])
    occupied_x, occupied_y = np.nonzero(grid_summed[x_lo:x_hi, cell_y:y_hi])
    cells = grid[occupied_x + x_lo, occupied_y + cell_y]
    return cells[cells != 0]


def _mean_intensity_on_line(x0, y0, x1, y1, image):
    """Mean grey level of ``image`` along the Bresenham line (x0, y0) -> (x1, y1).

    Points are (x, y) = (column, row) and are clipped to the image bounds.
    """
    line = np.array(list(bresenham(x0, y0, x1, y1)))
    cols = np.clip(line[:, 0], 0, image.shape[1] - 1)
    rows = np.clip(line[:, 1], 0, image.shape[0] - 1)
    return np.mean(image[rows, cols])


m = []  # mean intensity of every tested line (histogram further below)
for start_node_index in tqdm(range(num_of_nodes)):
    start_node_x, start_node_y = nodes_index2global_nodes_coord[start_node_index]
    grid_cell_id = image_coord2grid_cell_index[start_node_x, start_node_y]
    cell_x, cell_y = grid_cell_index2grid_cell_coord[grid_cell_id]

    for node_index in _grid_window_nodes(cell_x, cell_y, K // 2, K):
        node_index = int(node_index)
        if node_index == start_node_index:
            continue  # the start node's own cell is part of the window: avoid self-loops
        end_node_x, end_node_y = nodes_index2global_nodes_coord[node_index]
        # Five lines shifted by -2..2 px in x, sampled as [y, x]
        # (the old code sampled img_viz_numpy[xs, xs], i.e. the image diagonal).
        mean_pixels = [
            _mean_intensity_on_line(start_node_x + dx, start_node_y, end_node_x + dx, end_node_y, img_viz_numpy)
            for dx in range(-2, 3)
        ]
        m.extend(mean_pixels)
        line_mean = np.mean(mean_pixels)
        # Keep the edge when the mean grey level is below 160 or above 222.
        if line_mean < 160 or line_mean > 222:
            g.add_edge(start_node_index, node_index)
   

In [None]:
# Padded output of preprocess_image_1 that the dynamic search samples.
plt.imshow(img_viz_numpy, cmap='gray', origin='lower')
plt.title('Padded preprocessed image (img_viz)')
plt.show()

In [None]:
# Lay the graph out at the (x, y) positions stored on its nodes.
pos = nx.get_node_attributes(g, 'pos')
plt.figure(figsize=(5, 5))
nx.draw(g, pos=pos, with_labels=True, font_size=10, node_size=500, node_color='lightblue')

In [None]:
# Distribution of mean grey levels per tested line, collected in `m` by the dynamic search.
fig, ax = plt.subplots()
ax.hist(m, bins=20)
ax.set_xlabel('mean grey level along a line')
ax.set_ylabel('number of lines')
ax.set_title('Dynamic search: mean intensity per candidate line')
plt.show()

In [None]:
# Scratch check: integer points of the Bresenham line from (-1, -4) to (3, 2),
# i.e. with negative start coordinates, collected into an array of (x, y) rows.
np.array(list(bresenham(-1, -4, 3, 2)))

In [None]:
#     for xi in range(-K//2,K//2+1):
#         cell_x_tmp = cell_x + xi
#         for yi in range(0,K):
#             cell_y_tmp = cell_y + yi
            
#             if check_borders((cell_x_tmp,cell_y_tmp),img_shape):



    # wave_flag=True
    # wave=0
    # indices=[]

    # while wave_flag:
        # center cell
#         center_point=[node_x, node_y+1]
        
#         if check_borders(center_point):
#             indices.append(center_point)
        
#         # border cells
#         for i in range(wave+1):
#             left_point=[node_x-1-i, node_y-1-i]
#             right_point=[node_x+1+i, node_y+1+i]
            
#         if check_borders(left_point):
#             indices.append(left_point)
            
#         if check_borders(right_point):
#             indices.append(right_point)

In [None]:
# TODO (not started): def wave_coords(x, y, grid, wave_step)
# Grid size in cells: axis 0 runs over cell x, axis 1 over cell y.
width, height = grid.shape[0], grid.shape[1]


In [None]:
# (cells along x, cells along y, cell_size, cell_size)
np.shape(grid)

In [None]:
# Inspect `cell_x`: the grid-cell x coordinate left over from the last iteration of the
# dynamic-search loop. Its value depends on execution order (hidden state).
cell_x

In [None]:
# FIXME: `grid_wraped` is not defined anywhere in this notebook (probably left over from a
# deleted cell), so this cell raises NameError on a fresh kernel. It may have been an
# earlier name for a per-cell grid such as `grid_summed` — confirm, then fix or delete.
plt.imshow(grid_wraped,cmap='gray', origin='lower' )
plt.show()

In [None]:
# First 20 x-rows (x = 0..19) of the pixel -> grid-cell index map.
fig, ax = plt.subplots()
ax.imshow(image_coord2grid_cell_index[:20], cmap='gray', origin='lower')
plt.show()

In [None]:
# Preview instead of dumping every cell: total number of cells and the first 10 entries.
len(grid_cell_index2grid_cell_coord), dict(list(grid_cell_index2grid_cell_coord.items())[:10])

In [None]:
# FIXME: `img_with_border` is not defined anywhere in this notebook, so this cell raises
# NameError on a fresh kernel. It may refer to the padded image (`img_viz` /
# `img_viz_numpy`) — confirm. Presumably meant to invert 8-bit grey levels (255 - value).
255-img_with_border

In [None]:
# Draw the current graph `g` at the (x, y) positions stored on its nodes.
pos = nx.get_node_attributes(g, name='pos')
nx.draw(g, pos=pos, node_color='lightblue', node_size=500, font_size=10, with_labels=True)
plt.show()

# Deprecated

Recorded in an earlier run: 22 entry points, 1770 nodes.

In [None]:
# Deprecated: one undirected ring graph per contour, merged into a single graph whose
# node names are prefixed with the contour number ("g<k>-").
gs = []
prefixes = []
for contour_id, contour_points in enumerate(cnts):
    n_points = len(contour_points)
    if n_points == 0:
        # An empty contour used to add the edge (-1, 0), creating nodes without `pos`
        # that make nx.draw fail.
        continue
    contour_graph = nx.Graph()
    for point_id, point in enumerate(contour_points):
        contour_graph.add_node(point_id, pos=point)
    # Connect consecutive points...
    contour_graph.add_edges_from((k, k + 1) for k in range(n_points - 1))
    # ...and close the ring; with a single point this used to create a self-loop.
    if n_points > 2:
        contour_graph.add_edge(n_points - 1, 0)
    gs.append(contour_graph)
    prefixes.append(f"g{contour_id}-")

names = tuple(prefixes)
g = nx.union_all(gs, rename=names)

pos = nx.get_node_attributes(g, 'pos')
plt.figure(figsize=(10, 10))
nx.draw(g, pos, with_labels=True, node_color='lightblue', node_size=500, font_size=10)
plt.show()

In [None]:
# Small reference graph: Dorogovtsev-Goltsev-Mendes graph of generation 3.
G = nx.dorogovtsev_goltsev_mendes_graph(n=3)
nx.draw(G, node_color='lightblue', with_labels=True, font_size=10, node_size=500)

In [None]:
# Benchmark graph for path enumeration: Dorogovtsev-Goltsev-Mendes graph of generation 7.
# Earlier candidates: nx.path_graph(5), nx.complete_multipartite_graph(4, 5).
# NOTE(review): the old note "input 22 points / output 66 points" presumably refers to
# the entry/exit counts of the contour graph — confirm.
G = nx.dorogovtsev_goltsev_mendes_graph(7)

len(G)  # number of nodes

In [None]:
%%time

# Time enumerating every simple path from node 0 to node 7 of the benchmark graph `G`.
# NOTE(review): the number of simple paths can grow exponentially with graph size, so
# this can be slow, and materialising the full list can use a lot of memory.
list(nx.all_simple_paths(G, source=0, target=7))