In [8]:
%matplotlib qt
import numpy as np
import matplotlib.pyplot as plt

## Comment out this block: it is specific to the author's machine.
## It appends the author's local `stempy` checkout to the import path.
import sys
import platform
if platform.system() == 'Darwin':
	sys.path.append('/Users/jiadongdan/Dropbox/stempy')
else:
	sys.path.append('D:\\Dropbox\\stempy')
## End of the author-machine-specific block.

from stempy.io import *
from stempy.utils import *
from stempy.plot import *
from stempy.datasets import *
from stempy.feature import *
from stempy.manifold import *
from stempy.clustering import *
from stempy.spatial import *
from stempy.signal import *

In [13]:
# Load the image and keep an 800 x 800 crop starting at row 17, column 32.
# NOTE(review): `dp` and `load_image` are not defined in this notebook; they
# presumably come from the `stempy` star-imports above -- confirm.
file_name = dp+"WS2Te2_rotated.tif"
img = load_image(file_name, normalized=True)[17:17+800, 32:32+800]

In [14]:
# Cut a (2*size) x (2*size) template out of `img`, centred on row 280, column 222.
size = 32
patch = img[280-size:280+size, 222-size:222+size]

In [15]:
from skimage.feature import match_template
# Correlate the template with the full image; `pad_input=True` pads the image
# so that the response map `hh` has the same shape as `img`.
hh = match_template(img, patch, pad_input=True)

In [16]:
# Detect peaks of the correlation map; the result `pp` is used below as the
# lattice points passed to `get_hex_p0_uv` and `HexCells`.
# NOTE(review): `local_max` is not defined in this notebook (presumably from
# `stempy`); the exact meaning of `min_distance`/`threshold` should be
# confirmed against that function.
threshold = 0.5
pp = local_max(hh, min_distance=8, threshold=threshold, plot=False)

In [17]:
# Load precomputed positions (`pts`) and their labels (`lbs`); both are passed
# to `HexCells.set_cell_lbs` further down.
# NOTE(review): these files are not produced anywhere in this notebook, and
# `load_pickle` presumably unpickles them -- only load files you trust.
pts = load_pickle('pts.pkl')
lbs = load_pickle('lbs.pkl')

In [18]:
from sklearn.cluster import KMeans
from sklearn.neighbors import NearestNeighbors

def estimate_r0(pts, k=7):
    """Estimate a neighbour cut-off radius from k-nearest-neighbour distances.

    The distances from every point to its `k` nearest neighbours are pooled
    (dropping the first returned neighbour, presumably the point itself), split
    into two groups with 2-means, and the midpoint between the two group means
    is returned.

    Parameters
    ----------
    pts : array of shape (n, 2)
        Point coordinates.
    k : int
        Number of neighbours requested per point (including the first one
        that is discarded).

    Returns
    -------
    float
        Midpoint between the mean distances of the two distance clusters.
    """
    knn = NearestNeighbors(n_neighbors=k).fit(pts)
    dists, _ = knn.kneighbors(pts)
    # Discard column 0 and flatten the rest into a single-feature sample matrix.
    samples = dists[:, 1:].reshape(-1, 1)

    labels = KMeans(n_clusters=2, random_state=0).fit(samples).labels_
    mean_a = samples[labels == 0].mean()
    mean_b = samples[labels == 1].mean()
    return (mean_a + mean_b) / 2

def get_hex_p0_uv(pts, r=None):
    """Estimate an origin point and two lattice vectors from a point set.

    Neighbour displacement vectors no longer than `r` are clustered into six
    groups; the two group means with the smallest absolute angle to the
    x-axis are returned as `u` and `v`.

    Parameters
    ----------
    pts : array of shape (n, 2)
        Point coordinates.
    r : float, optional
        Maximum neighbour distance. Estimated with `estimate_r0` when None.

    Returns
    -------
    tuple
        `(p0, u, v)` where `p0` is the point of `pts` nearest to the centroid
        and `u`, `v` are the two selected mean displacement vectors.
    """
    if r is None:
        r = estimate_r0(pts)
    knn = NearestNeighbors(n_neighbors=7).fit(pts)
    neighbor_ids = knn.kneighbors(pts, return_distance=False)

    # Displacement from every point to each of its returned neighbours.
    offsets = (pts[neighbor_ids] - pts[:, np.newaxis, :]).reshape(-1, 2)
    lengths = np.hypot(offsets[:, 0], offsets[:, 1])
    # Keep non-zero displacements within the cut-off radius.
    within = np.logical_and(lengths <= r, lengths > 0)
    bonds = offsets[within]

    labels = KMeans(n_clusters=6, random_state=0).fit(bonds).labels_
    directions = np.array([bonds[labels == c].mean(axis=0) for c in range(6)])
    abs_angles = np.abs(np.rad2deg(np.arctan2(directions[:, 1], directions[:, 0])))
    first, second = np.argsort(abs_angles)[0:2]
    u, v = directions[first], directions[second]

    # Origin: the input point closest to the centroid of all points.
    centroid = pts.mean(axis=0)[np.newaxis, :]
    origin_id = knn.kneighbors(centroid, return_distance=False)[0][0]
    p0 = pts[origin_id]
    return (p0, u, v)

In [19]:
# Origin point and two lattice vectors estimated from the detected peaks `pp`.
p0, u, v = get_hex_p0_uv(pp)

In [20]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mc
from matplotlib.path import Path
from matplotlib.patches import Polygon
from sklearn.neighbors import NearestNeighbors
from scipy.linalg import inv
from skimage.measure import label, regionprops
from skimage.transform import estimate_transform

from matplotlib.collections import LineCollection

import base64
import io
from PIL import Image
from PIL.PngImagePlugin import PngInfo
from skimage.transform import rescale
from scipy.sparse.csgraph import connected_components

import numbers
from numpy.lib.stride_tricks import as_strided

from sklearn.cluster import DBSCAN


def _get_p0_u_v(pts):
    """Pick an origin point and two neighbour vectors around it.

    The origin `p0` is the point of `pts` closest to the centroid. Among the
    six further neighbours returned for `p0`, the two whose displacement has
    the smallest absolute angle to the x-axis define `u` and `v`.

    Parameters
    ----------
    pts : array of shape (n, 2)
        Point coordinates.

    Returns
    -------
    tuple
        `(p0, u, v)`.
    """
    knn = NearestNeighbors(n_neighbors=7, algorithm='ball_tree').fit(pts)
    centroid = pts.mean(axis=0)[np.newaxis, :]
    origin_id = knn.kneighbors(centroid, return_distance=False)[0][0]
    p0 = pts[origin_id]

    # Skip the first returned neighbour of p0 and keep the remaining six.
    ring_ids = knn.kneighbors(p0[np.newaxis, :], return_distance=False)[0][1:]
    ring = pts[ring_ids] - p0
    abs_angles = np.abs(np.rad2deg(np.arctan2(ring[:, 1], ring[:, 0])))
    best_two = ring_ids[np.argsort(abs_angles)][0:2]
    return p0, pts[best_two[0]] - p0, pts[best_two[1]] - p0

def _remove_duplicate_pts(pts):
    return np.unique(pts, axis=0)


def points_to_cell_inds(pts, threshold=None):
    """Find four-point cells spanned by the lattice vectors `u` and `v`.

    For every point `p`, the nearest input points to `p + u`, `p + v` and
    `p + u + v` are looked up. A cell is kept when the mean of those three
    nearest-neighbour distances is below `threshold`.

    Parameters
    ----------
    pts : array of shape (n, 2)
        Point coordinates.
    threshold : float, optional
        Maximum mean mismatch distance. Defaults to 0.2 times the value
        returned by `_estimate_r0(pts)`.

    Returns
    -------
    array of shape (m, 4)
        Point indices of each kept cell in the order
        `(p, p + u, p + u + v, p + v)`.
    """
    if threshold is None:
        threshold = _estimate_r0(pts) * 0.2
    _, u, v = _get_p0_u_v(pts)
    knn = NearestNeighbors(n_neighbors=1, algorithm='ball_tree').fit(pts)

    dist_u, ids_u = knn.kneighbors(pts + u)
    dist_v, ids_v = knn.kneighbors(pts + v)
    dist_uv, ids_uv = knn.kneighbors(pts + u + v)

    mismatch = np.hstack([dist_u, dist_v, dist_uv]).mean(axis=1)
    own_ids = np.arange(len(pts))[:, np.newaxis]
    # Corner order walks around the cell: p, p+u, p+u+v, p+v.
    corners = np.hstack([own_ids, ids_u, ids_uv, ids_v])
    return corners[mismatch < threshold]


# Estimate the nearest-neighbour radius of a point set.
def _estimate_r0(pts):
    """Return the mean distance from each point to its second returned neighbour.

    With `n_neighbors=2` the first returned neighbour is presumably the query
    point itself, so the second column holds the nearest-neighbour distance.
    """
    knn = NearestNeighbors(n_neighbors=2, algorithm='ball_tree').fit(pts)
    dists, _ = knn.kneighbors(pts)
    return dists[:, 1].mean()


def get_ordered_lbs(pts, lbs, inds):
    """Build a fixed-width, rotation-normalised label row for every index group.

    Parameters
    ----------
    pts : array of shape (n, 2)
        Point coordinates.
    lbs : array of shape (n,)
        Label of each point.
    inds : sequence of index arrays
        One array of point indices per group; groups may differ in length.

    Returns
    -------
    array of shape (len(inds), nmax)
        For each group, the labels of its points sorted by angle around the
        group centroid and cyclically rolled so the smallest label comes
        first. Rows are right-padded with -1 up to the largest group size.
    """
    width = np.max([len(group) for group in inds])

    def ordered_row(group):
        count = len(group)
        if count == 0:
            return [-1] * width
        if count == 1:
            return [lbs[group[0]]] + [-1] * (width - 1)
        centered = pts[group] - (pts[group]).mean(axis=0)
        theta = np.arctan2(centered[:, 1], centered[:, 0]) + np.pi
        # Labels in angular order around the centroid ...
        by_angle = lbs[group][np.argsort(theta)]
        # ... rolled so the minimum label is the first entry.
        start = np.argmin(by_angle)
        return np.roll(by_angle, -start).tolist() + [-1] * (width - count)

    return np.array([ordered_row(group) for group in inds])


def reorder_lbs(lbs, vals=None, mode='min'):
    """Permute label values so that they follow the ordering of `vals`.

    The set of label values is preserved; only their assignment changes. The
    label whose value in `vals` is smallest (``mode='min'``) or largest (any
    other `mode`) receives the smallest label value, the next one the second
    smallest, and so on.

    Parameters
    ----------
    lbs : array-like of int
        Labels to remap. They do not need to be contiguous or start at 0.
    vals : array-like, optional
        One value per distinct label, in the sorted order of `np.unique(lbs)`.
        Defaults to the number of occurrences of each label.
    mode : str
        ``'min'`` sorts ascending by `vals`; any other value sorts descending.

    Returns
    -------
    ndarray
        Remapped labels with the same shape as `lbs`.

    Raises
    ------
    ValueError
        If `vals` does not have exactly one entry per distinct label.
    """
    lbs = np.asarray(lbs)
    unique, inverse, counts = np.unique(lbs, return_inverse=True, return_counts=True)
    if vals is None:
        vals = counts
    vals = np.asarray(vals)
    if vals.shape != unique.shape:
        raise ValueError('vals must contain one value per distinct label '
                         '({} expected, got shape {}).'.format(len(unique), vals.shape))

    lbs_order = np.argsort(vals)
    if mode != 'min':
        lbs_order = lbs_order[::-1]

    # remapped[k] is the new value of the label unique[k]: the label ranked
    # r-th by `vals` takes the r-th smallest label value. Indexing through
    # `inverse` (positions in `unique`) instead of the raw label values keeps
    # this correct for labels that are not exactly 0..n-1.
    remapped = np.empty_like(unique)
    remapped[lbs_order] = unique
    return remapped[np.ravel(inverse)].reshape(lbs.shape)

def get_rowcol_lbs(X, eps=0.3):
    """Cluster samples with DBSCAN and rank the clusters by their mean value.

    Parameters
    ----------
    X : array of shape (n, 1)
        Samples to cluster (callers in this file pass a single coordinate
        column).
    eps : float
        DBSCAN neighbourhood radius.

    Returns
    -------
    array of shape (n,)
        Cluster labels reordered by `reorder_lbs` using the per-cluster mean
        of `X` as the sort value.
    """
    # min_samples=1 means every sample belongs to some cluster.
    labels = DBSCAN(eps=eps, min_samples=1).fit(X).labels_
    cluster_means = np.array([X[labels == c].mean() for c in np.unique(labels)])
    return reorder_lbs(labels, cluster_means)

def pts2grid(pts):
    """Snap 2-D coordinates to integer grid indices.

    Each coordinate column is clustered independently with `get_rowcol_lbs`;
    every coordinate is then replaced by the rank of its cluster among the
    sorted distinct cluster labels of that column.

    Parameters
    ----------
    pts : array of shape (n, 2)
        Coordinates to discretise.

    Returns
    -------
    int array of shape (n, 2)
        Grid indices for both columns.
    """
    col0_lbs = get_rowcol_lbs(pts[:, 0:1])
    col1_lbs = get_rowcol_lbs(pts[:, 1:2])
    grid = pts.copy()
    # Second coordinate first, then the first one, mirroring the two passes.
    for column, column_lbs in ((1, col1_lbs), (0, col0_lbs)):
        for rank, lb in enumerate(np.unique(column_lbs)):
            grid[column_lbs == lb, column] = rank
    return grid.astype(int)

class HexCells:
    """Four-point lattice cells built from a 2-D point set.

    Attributes set by the constructor
    ---------------------------------
    pts : de-duplicated input points shifted by `(dx, dy)`.
    u, v : lattice vectors returned by `_get_p0_u_v`.
    inds : (m, 4) point indices of every cell (`points_to_cell_inds`).
    cells : (m, 4, 2) corner coordinates of every cell.
    centers : (m, 2) mean corner position of every cell.
    centers_ : (m, 2) integer grid index of every cell centre.
    tform : affine transform estimated from `centers_` to `centers`.
    pts_ : `pts` mapped through the inverse of `tform`.
    pos, lbs, cell_lbs : None until `set_cell_lbs` is called.

    Attributes set by `connected_cells` / `get_connected_cells`
    -----------------------------------------------------------
    components : 1-D object array, one label patch per distinct shape.
    ids : 1-D object array, cell-id patch matching each component
        (0 outside the region).
    cnts : number of regions that share each distinct shape.
    """

    def __init__(self, pts, dx=0, dy=0):
        dxdy = np.array([dx, dy])
        self.pts = _remove_duplicate_pts(pts) + dxdy
        p0, self.u, self.v = _get_p0_u_v(self.pts)
        self.inds = points_to_cell_inds(self.pts)
        self.cells = self.pts[self.inds]

        # region(cell) centers
        self.centers = np.array([self.pts[ind].mean(axis=0) for ind in self.inds])

        # Express the centres in the (u, v) basis and snap them to a grid.
        p0, u, v = get_hex_p0_uv(self.centers)
        uv = np.vstack([u, v])
        ij = (self.centers - p0).dot(inv(uv))
        self.centers_ = pts2grid(ij)

        # get transform
        self.tform = estimate_transform('affine', self.centers_, self.centers)
        self.pts_ = self.tform.inverse(self.pts)

        self.pos = None
        self.lbs = None
        self.cell_lbs = None

    def _grid_shape(self):
        """Shape of the square matrix that holds every grid index in `centers_`."""
        side = max(np.ptp(self.centers_[:, 0]), np.ptp(self.centers_[:, 1])) + 1
        return (side, side)

    @staticmethod
    def _pack_object_array(items):
        """Pack arrays into a 1-D object array, one array per element.

        `np.array(items, dtype=object)` is avoided on purpose: it stacks
        equally shaped arrays into one higher-dimensional array and can raise
        for arrays that only share their leading dimension.
        """
        packed = np.empty(len(items), dtype=object)
        for k, item in enumerate(items):
            packed[k] = item
        return packed

    @property
    def image(self):
        """Square float matrix of cell labels indexed `[row j, column i]`.

        Grid positions without a cell hold -1.

        Raises
        ------
        ValueError
            If `set_cell_lbs` has not been called yet.
        """
        if self.cell_lbs is None:
            raise ValueError('cell_lbs is not set.')
        cells_matrix = np.full(self._grid_shape(), -1.0)
        # need cell_lbs
        for idx, (i, j) in enumerate(self.centers_):
            cells_matrix[j, i] = self.cell_lbs[idx]
        return cells_matrix

    @property
    def id_image(self):
        """Square int matrix holding the cell index at each grid position.

        Grid positions without a cell hold -1. The matrix only depends on
        `centers_`, so it is available before `set_cell_lbs` is called.
        """
        id_matrix = np.full(self._grid_shape(), -1, dtype=int)
        for idx, (i, j) in enumerate(self.centers_):
            id_matrix[j, i] = idx
        return id_matrix

    # this is important
    def set_cell_lbs(self, pos, lbs):
        """Assign a label to every cell from the labelled points it contains.

        Parameters
        ----------
        pos : array of shape (k, 2)
            Positions of labelled points.
        lbs : array of shape (k,)
            Label of each position.

        Side effects
        ------------
        Sets `self.pos` / `self.lbs` (per-cell lists of contained positions
        and labels) and `self.cell_lbs` (one integer label per cell; cells
        with the same ordered label pattern share a label, and the most
        frequent pattern receives the smallest label).
        """
        polys = self.pts[self.inds]
        # indices of points inside each cell
        inds = []
        # one entry per cell, in the order of self.inds
        self.pos = []
        self.lbs = []
        for poly in polys:
            ind = np.flatnonzero(Path(poly).contains_points(pos))
            inds.append(ind)
            self.pos.append(pos[ind])
            self.lbs.append(lbs[ind])
        lbs_ordered = get_ordered_lbs(pos, lbs, inds)
        _, inverse = np.unique(lbs_ordered, axis=0, return_inverse=True)
        # Flatten: the inverse is used as a flat per-cell label vector.
        self.cell_lbs = reorder_lbs(np.ravel(inverse), mode='max')

    def _collect_connected_cells(self, min_area, max_area, colors=None):
        """Shared implementation of `connected_cells` / `get_connected_cells`.

        Labels the 4-connected regions of cells whose label is > 0, keeps the
        regions with `min_area <= area <= max_area`, groups regions that are
        equal up to 90-degree rotation and returns one `ConnectedCells` per
        distinct shape (the first region found for that shape).
        """
        threshold = 0
        connectivity = 1
        image = self.image
        id_image = self.id_image
        # label image
        lbs_image = label(image > threshold, connectivity=connectivity)
        regions = [e for e in regionprops(lbs_image) if min_area <= e.area <= max_area]

        if not regions:
            self.components = self._pack_object_array([])
            self.ids = self._pack_object_array([])
            self.cnts = np.array([], dtype=int)
            return []

        # sort regions by area
        areas = np.array([e.area for e in regions])
        order = np.argsort(areas)

        # Per region: footprint inside its bounding box, label patch and
        # cell-id patch. Everything outside the footprint is zeroed so that
        # other regions sharing the bounding box do not leak in.
        masks = [regions[k].image for k in order]
        components = [np.where(regions[k].image, image[regions[k].slice], 0) for k in order]
        ids = [np.where(regions[k].image, id_image[regions[k].slice], 0) for k in order]

        # Group rotation-equivalent shapes; keep the first region of each group.
        _, cs_lbs = unique_components(self._pack_object_array(components))
        first = [np.flatnonzero(cs_lbs == e)[0] for e in np.unique(cs_lbs)]
        self.components = self._pack_object_array([components[k] for k in first])
        self.ids = self._pack_object_array([ids[k] for k in first])
        _, self.cnts = np.unique(cs_lbs, return_counts=True)

        cs = []
        for k in first:
            # Select ids through the footprint rather than `ids != 0`, which
            # would silently drop the cell whose index is 0.
            idx = ids[k][masks[k]]
            pos = np.vstack([self.pos[e] for e in idx])
            lbs = np.hstack([self.lbs[e] for e in idx])
            cs.append(ConnectedCells(components[k], self.cells[idx], pos, lbs, colors=colors))
        return cs

    @property
    def connected_cells(self):
        """Distinct connected groups made of more than one cell.

        Returns a list of `ConnectedCells`, one per distinct shape, and
        refreshes `self.components`, `self.ids` and `self.cnts`.
        """
        return self._collect_connected_cells(min_area=2, max_area=np.inf)

    def get_connected_cells(self, min_area=3, max_area=10, colors=None):
        """Distinct connected groups with `min_area <= area <= max_area` cells.

        Parameters
        ----------
        min_area, max_area : int
            Inclusive bounds on the number of cells in a group.
        colors : sequence of matplotlib colours, optional
            Colour table forwarded to every `ConnectedCells`.

        Returns
        -------
        list of ConnectedCells
            One entry per distinct shape, ordered like `self.components`;
            `self.cnts` holds how many regions share each shape.
        """
        return self._collect_connected_cells(min_area, max_area, colors=colors)

    def plot(self, ax=None, cs=None, **kwargs):
        """Draw every cell as a semi-transparent polygon.

        Parameters
        ----------
        ax : matplotlib Axes, optional
            Target axes; a new 7.2 x 7.2 figure is created when None.
        cs : sequence of colours, optional
            Face colours indexed by cell label. Defaults to 'white' followed
            by 'C0' .. 'C9'.
        **kwargs
            Accepted but currently unused.
        """
        if ax is None:
            fig, ax = plt.subplots(1, 1, figsize=(7.2, 7.2))

        if cs is None:
            cs = ['white'] + ['C{}'.format(e) for e in range(10)]
        for i, e in enumerate(self.inds):
            if self.cell_lbs is None:
                fc = 'C0'
            else:
                # NOTE(review): labels wrap with `% 10` although the default
                # table has 11 entries, so its last colour is never used and
                # label 10 maps back to 'white' -- confirm this is intended.
                fc = cs[self.cell_lbs[i] % 10]
            poly = Polygon(self.pts[e], alpha=0.5, ec='k', fc=fc)
            ax.add_patch(poly)
        ax.axis('equal')
        
        
# =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
# CellsImages
# =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=

def unique_components(components):
    """Group 2-D arrays that are equal up to a rotation by a multiple of 90 degrees.

    Parameters
    ----------
    components : sequence of 2-D arrays
        Arrays to compare; shapes may differ.

    Returns
    -------
    components_ : 1-D object array
        The first member of every group, ordered by group label.
    lbs : int array of shape (len(components),)
        Group label of every input array.
    """
    n = len(components)
    if n == 0:
        return np.empty(0, dtype=object), np.empty(0, dtype=int)

    # Compute the four rotations of every component once instead of once per pair.
    rotations = [[np.rot90(e, k=k) for k in range(4)] for e in components]

    def is_same(i, j):
        return any(np.array_equal(rotated, components[j]) for rotated in rotations[i])

    # Adjacency matrix of the "same up to rotation" relation; its connected
    # components are the groups.
    g = np.array([[is_same(i, j) for j in range(n)] for i in range(n)], dtype=int)
    num, lbs = connected_components(g)

    representatives = [np.flatnonzero(lbs == e)[0] for e in np.unique(lbs)]
    # Fill a preallocated object array element by element:
    # np.array(list, dtype=object) would stack equally shaped arrays into one
    # higher-dimensional array (and can raise when only leading dims match).
    components_ = np.empty(len(representatives), dtype=object)
    for k, rep in enumerate(representatives):
        components_[k] = components[rep]
    return components_, lbs


# Default colour table of ConnectedCells, used when no `colors` argument is
# given; cell labels index into it. Entry 0 is '1' (matplotlib grey-level
# notation for white), followed by 20 hex colours.
mpl_21 = ['1', '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
          '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf',
          '#3397dc', '#ff993e', '#3fca3f', '#df5152', '#a985ca',
          '#ad7165', '#e992ce', '#999999', '#dbdc3c', '#35d8e9']


def add_grids(rgba, s=20):
    """Draw white grid lines every `s` pixels onto an RGBA image, in place.

    Parameters
    ----------
    rgba : uint8 array of shape (h, w, 4)
        Image to draw on; it is modified in place.
    s : int
        Grid spacing in pixels. Lines are centred on multiples of `s`
        (excluding 0) and are two pixels thick.

    Returns
    -------
    ndarray
        The same `rgba` array, for convenience.

    Raises
    ------
    ValueError
        If `s` is not positive.
    """
    if s <= 0:
        raise ValueError('s must be a positive integer.')
    h, w = rgba.shape[0:2]

    d = 1  # half thickness of a grid line, in pixels
    white = [255, 255, 255, 255]
    for y in range(s, h, s):
        rgba[y - d:y + d, :] = white

    for x in range(s, w, s):
        rgba[:, x - d:x + d, :] = white

    return rgba

def cells2points(cells, return_centers=False):
    """Return the distinct corner points of all non-zero cells of a 2-D grid.

    Every non-zero entry at index `(a, b)` is treated as a unit square centred
    on `(a, b)`; its four corners lie at `(a +/- 0.5, b +/- 0.5)`.

    Parameters
    ----------
    cells : 2-D array
        Grid whose non-zero entries mark occupied cells.
    return_centers : bool
        Also return the float centres of the occupied cells.

    Returns
    -------
    pts : array of shape (k, 2)
        Sorted distinct corner coordinates.
    centers : array of shape (m, 2)
        Only when `return_centers` is True.
    """
    first_idx, second_idx = np.nonzero(cells)
    centers = np.column_stack([first_idx, second_idx]).astype(float)

    # Same corner order as stacking (+,+), (-,-), (+,-), (-,+) per centre.
    half_steps = np.array([[0.5, 0.5], [-0.5, -0.5], [0.5, -0.5], [-0.5, 0.5]])
    corners = (centers[np.newaxis, :, :] + half_steps[:, np.newaxis, :]).reshape(-1, 2)
    distinct = np.unique(corners, axis=0)

    if return_centers:
        return distinct, centers
    return distinct


def sort_inds_by_angle(pts, inds):
    """Sort every index group by angle around the centroid of its points.

    Parameters
    ----------
    pts : array of shape (n, 2)
        Point coordinates.
    inds : iterable of index arrays
        Groups of point indices.

    Returns
    -------
    ndarray or list
        A 2-D array when all groups have the same length, otherwise a list
        of 1-D arrays.
    """
    ordered = []
    group_sizes = set()
    for group in inds:
        relative = pts[group] - pts[group].mean(axis=0)
        theta = np.arctan2(relative[:, 1], relative[:, 0]) + np.pi
        ordered.append(group[np.argsort(theta)])
        group_sizes.add(len(group))
    # A rectangular result is only possible when every group has one common size.
    if len(group_sizes) == 1:
        return np.array(ordered)
    return ordered


def get_pairs(pts, centers):
    """Find the edges of the four-point cells around the given centres.

    Parameters
    ----------
    pts : array of shape (n, 2)
        Corner coordinates.
    centers : array of shape (m, 2)
        Cell centres.

    Returns
    -------
    pairs : array of shape (k, 2)
        Distinct index pairs (smaller index first) joining angularly
        consecutive corners of each cell.
    inds : array of shape (m, 4)
        The four nearest corner indices of every centre, sorted by angle.
    """
    knn = NearestNeighbors(n_neighbors=4, algorithm='auto').fit(pts)
    corner_ids = knn.kneighbors(centers, return_distance=False)
    corner_ids = sort_inds_by_angle(pts, corner_ids)

    # Pair every corner with its predecessor in angular order.
    previous_ids = np.roll(corner_ids, 1, axis=1)
    edges = np.dstack([corner_ids, previous_ids]).reshape(-1, 2)
    edges = np.unique(np.sort(edges, axis=1), axis=0)
    return edges, corner_ids


class ConnectedCells:
    """A small labelled patch of grid cells with a PNG/HTML notebook preview.

    Parameters
    ----------
    data : 2-D array
        Cell labels; entries <= 0 are treated as empty.
    cells : array, optional
        Corner coordinates of the cells in the patch (used by `plot`).
    pos : array, optional
        Positions of the labelled points inside the patch (used by `plot`).
    lbs : array, optional
        Labels of `pos` (used by `plot`).
    tform : optional
        Accepted but currently unused.
    colors : sequence of matplotlib colours, optional
        Colour table indexed by cell label. Defaults to `mpl_21`.

    Attributes
    ----------
    mask : boolean array, True where `data > 0`.
    data : `data` with every non-positive entry set to 0.
    shape, h, w : shape of `data`.
    num : number of occupied cells.
    colors : (n, 4) RGBA colour table.
    pixels : uint8 RGBA preview (20 pixels per cell, white grid lines, pure
        white pixels made transparent).
    """

    def __init__(self, data, cells=None, pos=None, lbs=None, tform=None, colors=None):
        self.mask = data > 0
        self.data = data * self.mask

        self.shape = data.shape
        self.h, self.w = data.shape
        self.num = np.sum(self.mask)

        if colors is None:
            colors = mpl_21
        self.colors = np.array([mc.to_rgba(e) for e in colors])

        # self.pixels is the RGBA representation: upscale every cell to a
        # 20 x 20 block (nearest neighbour) and look its label up in the
        # colour table. Labels must therefore be smaller than len(colors).
        data_256 = rescale(self.data, scale=20, order=0).astype(int)
        data_256[data_256 < 0] = 0

        self.pixels = (self.colors[data_256] * 255).astype(np.uint8)
        # add grids
        self.pixels = add_grids(self.pixels)
        # make white pixels transparent (this includes the white grid lines)
        alpha = ~np.all(self.pixels == 255, axis=2) * 255
        self.pixels[:, :, 3] = alpha

        self.cells = cells
        self.pos = pos
        self.lbs = lbs

    def _repr_png_(self):
        """Generate a PNG representation of the ConnectedCells."""
        png_bytes = io.BytesIO()
        title = 'level-{}'.format(self.num)
        pnginfo = PngInfo()
        pnginfo.add_text('Title', title)
        pnginfo.add_text('Description', title)
        Image.fromarray(self.pixels).save(png_bytes, format='png', pnginfo=pnginfo)
        return png_bytes.getvalue()

    def _repr_html_(self):
        """Generate an HTML representation of the ConnectedCells."""
        png_bytes = self._repr_png_()
        png_base64 = base64.b64encode(png_bytes).decode('ascii')

        # The trailing flex container is closed explicitly so the snippet is
        # well-formed HTML.
        return ('<div style="vertical-align: middle;">'
                f'<strong>{self.num}</strong> '
                '</div>'
                '<div class="cmap"><img '
                f'alt="{self.num} colormap" '
                f'title="{self.num}" '
                'style="border: 1px solid #555;" '
                f'src="data:image/png;base64,{png_base64}"></div>'
                '<div style="vertical-align: middle; '
                f'max-width: {258}px; '
                'display: flex; justify-content: space-between;">'
                '</div>')

    def rot90(self, k=1):
        """Return a new ConnectedCells whose data is rotated by `k` * 90 degrees.

        Only the label data is carried over; `cells`, `pos`, `lbs` and a
        custom colour table are not.
        """
        return ConnectedCells(data=np.rot90(self.data, k=k))

    def _contains_pattern(self, pattern, pattern_mask):
        """Whether `pattern` occurs in `self.data` at some offset (no rotation).

        A window matches when its entries, restricted to `pattern_mask`,
        equal `pattern`; window entries outside the mask are ignored.
        """
        ph, pw = pattern.shape
        sh, sw = self.data.shape
        # The pattern must fit along each axis separately; a pattern that is
        # longer than it is wide may still fit a non-square patch.
        if sh < ph or sw < pw:
            return False

        stride_r, stride_c = self.data.strides
        # Read-only 4-D view of every (ph, pw) window of self.data.
        windows = as_strided(self.data,
                             shape=(sh - ph + 1, sw - pw + 1, ph, pw),
                             strides=(stride_r, stride_c, stride_r, stride_c),
                             writeable=False)
        matches = (windows * pattern_mask == pattern).all(axis=(2, 3))
        return bool(matches.any())

    def _contains(self, c):
        """Whether the patch `c` occurs inside this patch without rotation."""
        return self._contains_pattern(c.data, c.mask)

    def contains(self, c):
        """Whether `c`, in any of its four 90-degree rotations, occurs inside this patch.

        The rotations are applied to the arrays of `c` directly, so no
        intermediate ConnectedCells (and no preview image) has to be built.
        """
        return np.any([self._contains_pattern(np.rot90(c.data, k=k), np.rot90(c.mask, k=k))
                       for k in range(4)])

    def plot(self, ax=None, colors=None, **kwargs):
        """Draw the cell outlines and the labelled points of this patch.

        Parameters
        ----------
        ax : matplotlib Axes, optional
            Target axes; a new 7.2 x 7.2 figure is created when None.
        colors : optional
            Forwarded to `colors_from_lbs` together with `self.lbs`.
        **kwargs
            Forwarded to `ax.scatter`.

        NOTE(review): requires `cells`, `pos` and `lbs` to be set, so it does
        not work on objects returned by `rot90`.
        """
        colors = colors_from_lbs(self.lbs, colors)

        if ax is None:
            fig, ax = plt.subplots(1, 1, figsize=(7.2, 7.2))

        for cell in self.cells:
            poly = Polygon(cell, ec='k', fc='none')
            ax.add_patch(poly)
        ax.scatter(self.pos[:, 0], self.pos[:, 1], c=colors, **kwargs)
        ax.axis('equal')

In [21]:
# Build the lattice cells from the detected peaks and label every cell from
# the precomputed positions/labels loaded earlier.
aa = HexCells(pp)
aa.set_cell_lbs(pts, lbs)

In [22]:
# Visual check: draw the labelled cells on top of the grayscale image.
fig, ax = plt.subplots(1, 1, figsize=(7.2, 7.2))
ax.imshow(img, cmap='gray')
aa.plot(ax)

In [23]:
# Alternative palette, kept for reference only: the assignment on the next
# line immediately replaces it with None.
colors = ['white'] + ['C3', 'C1', 'C2', 'C0', 'C4', 'C5', 'C6', 'C7', 'C8', 'C9'] + ['gray']
# NOTE(review): `colors` is not passed on below; `get_connected_cells` is
# called with a literal `colors=None`.
colors = None

In [24]:
# Distinct connected groups of 3 to 6 cells; this call also refreshes
# `aa.cnts`, which is read a few cells below.
cs = aa.get_connected_cells(min_area=3, max_area=6, colors=None)

In [25]:
# Build the hierarchy graph of the connected groups and nudge selected nodes.
# NOTE(review): `HGraph` and `update_pos` are not defined in this notebook
# (presumably from `stempy`); the node indices and shifts below are manual
# layout tweaks whose units should be confirmed against `HGraph.update_pos`.
bb = HGraph(cs)
for ii in [3, 4, 5, 6, 7]:
    bb.update_pos(ii, shift=20)
for ii in [8, 9, 10]:
    bb.update_pos(ii, shift=-20)
bb.update_pos(21, shift=1)

In [26]:
# Quick preview of the hierarchy graph before assembling the final figure.
bb.plot(zoom=0.010)

In [27]:
# Numerators (count per distinct shape) and common denominator (total count)
# of the fractions printed by `plot_hierarchy`, which reads these globals.
nu = aa.cnts
deno = aa.cnts.sum()

In [28]:
from matplotlib.patches import Polygon, Rectangle

def fig_add_diamond(fig, x, y, a, **kwargs):
    """Add a rhombus to a figure, positioned in figure-fraction coordinates.

    Parameters
    ----------
    fig : matplotlib Figure
    x, y : float
        Position of the left corner, in figure fractions.
    a : float
        Side-length scale as a fraction of the figure width.
    **kwargs
        Forwarded to `Polygon`.

    Returns
    -------
    Polygon
        The patch, defined in pixel coordinates (`transform=None`).
    """
    to_pixel = fig.transFigure
    left_x, left_y = to_pixel.transform([x, y])
    side = to_pixel.transform([a, a])[0]

    # Corners relative to the left corner: left, bottom, right, top.
    corners = np.array([(0, 0),
                        (np.sqrt(3)*side/2, -side/2),
                        (np.sqrt(3)*side, 0),
                        (np.sqrt(3)*side/2, +side/2)])
    corners[:, 0] += left_x
    corners[:, 1] += left_y

    diamond = Polygon(xy=corners, transform=None, **kwargs)
    fig.add_artist(diamond)
    return diamond

def fig_add_square(fig, x, y, a, **kwargs):
    """Add a square centred on a figure-fraction position.

    Parameters
    ----------
    fig : matplotlib Figure
    x, y : float
        Centre of the square, in figure fractions.
    a : float
        Side length as a fraction of the figure width.
    **kwargs
        Forwarded to `Rectangle`.

    Returns
    -------
    Rectangle
        The patch, defined in pixel coordinates (`transform=None`).
    """
    to_pixel = fig.transFigure
    center_x, center_y = to_pixel.transform([x, y])
    side = to_pixel.transform([a, a])[0]

    lower_left = (center_x - side/2, center_y - side/2)
    square = Rectangle(lower_left, side, side, transform=None, **kwargs)
    fig.add_artist(square)
    return square

def add_dots(fig, poly, r=10, c1='C0', c2='C1'):
    """Add two circles at one third and two thirds of a four-corner polygon's width.

    Parameters
    ----------
    fig : matplotlib Figure
    poly : Polygon
        Four-corner patch; its first and third vertices define the span.
    r : float
        Circle radius in pixels.
    c1, c2 : colour
        Face colours of the first and second circle.
    """
    first_corner, _, third_corner, _ = poly.get_xy()[:-1]
    span = third_corner[0] - first_corner[0]
    x0, y0 = first_corner

    dot_a = plt.Circle((x0 + span/3, y0), r, fc=c1, transform=None)
    dot_b = plt.Circle((x0 + span/3*2, y0), r, fc=c2, transform=None)
    for dot in (dot_a, dot_b):
        fig.add_artist(dot)
    

def ax_add_diamond(ax, x, y, a, **kwargs):
    """Add a rhombus centred on a data position of an axes.

    The shape is constructed in pixel space, using the pixel height of a
    data-space length `a`, and then mapped back to data coordinates.

    Parameters
    ----------
    ax : matplotlib Axes
    x, y : float
        Centre of the rhombus in data coordinates.
    a : float
        Side-length scale, measured along the y-axis in data units.
    **kwargs
        Forwarded to `Polygon`.

    Returns
    -------
    Polygon
        The added patch (in data coordinates).
    """
    to_pixel = ax.transData
    center_x, center_y = to_pixel.transform([x, y])
    origin_px = to_pixel.transform([0, 0])
    offset_px = to_pixel.transform([0, a])
    side = np.abs(offset_px - origin_px)[1]

    corners = np.array([(0, 0),
                        (np.sqrt(3)*side/2, -side/2),
                        (np.sqrt(3)*side, 0),
                        (np.sqrt(3)*side/2, +side/2)])
    # Centre the shape on the requested position.
    corners = corners - corners.mean(axis=0)
    corners[:, 0] += center_x
    corners[:, 1] += center_y

    corners = ax.transData.inverted().transform(corners)
    diamond = Polygon(xy=corners, transform=ax.transData, **kwargs)
    ax.add_patch(diamond)
    return diamond

def ax_add_square(ax, x, y, a, **kwargs):
    """Add a square (in pixel space) centred on a data position of an axes.

    Parameters
    ----------
    ax : matplotlib Axes
    x, y : float
        Centre of the square in data coordinates.
    a : float
        Side length, measured along the y-axis in data units.
    **kwargs
        Forwarded to `Polygon`.

    Returns
    -------
    Polygon
        The added patch (in data coordinates).
    """
    to_pixel = ax.transData
    center_x, center_y = to_pixel.transform([x, y])
    origin_px = to_pixel.transform([0, 0])
    offset_px = to_pixel.transform([0, a])
    side = np.abs(offset_px - origin_px)[1]

    corners = np.array([(0, 0), (0, side), (side, side), (side, 0)])
    # Centre the shape on the requested position.
    corners = corners - corners.mean(axis=0)
    corners[:, 0] += center_x
    corners[:, 1] += center_y

    corners = ax.transData.inverted().transform(corners)
    square = Polygon(xy=corners, transform=ax.transData, **kwargs)
    ax.add_patch(square)
    return square

def ax_add_dots(ax, poly, r=10, c1='C0', c2='C1', **kwargs):
    """Scatter two dots at one third and two thirds of a four-corner polygon's width.

    Parameters
    ----------
    ax : matplotlib Axes
    poly : Polygon
        Four-corner patch; its first and third vertices define the span.
    r : float
        Accepted but not used.
    c1, c2 : colour
        Colours of the first and second dot.
    **kwargs
        Forwarded to `ax.scatter`.

    Side effects
    ------------
    Resets both axis limits to (0, 1).
    """
    first_corner, _, third_corner, _ = poly.get_xy()[:-1]
    span = third_corner[0] - first_corner[0]
    x0, y0 = first_corner

    ax.scatter(x0 + span/3, y0, c=c1, **kwargs)
    ax.scatter(x0 + span/3*2, y0, c=c2, **kwargs)

    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    
    
def plot_block_legend(ax, box_color='C0', c1='C0', c2='C1', **kwargs):
    """Draw a "square : rhombus with two dots" legend entry in an axes.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw into; its frame and ticks are switched off.
    box_color : colour
        Face colour of the square.
    c1, c2 : colour
        Colours of the two dots inside the rhombus.
    **kwargs
        Forwarded to `ax_add_dots` (and from there to `ax.scatter`).
    """
    side = 0.2*get_ax_aspect(ax)
    outline = ax_add_diamond(ax, 0.7, 0.5, side, fc='none', ec='#2d3742')
    ax_add_square(ax, 0.3, 0.5, side, fc=box_color)
    ax_add_dots(ax, outline, c1=c1, c2=c2, **kwargs)
    # Colon between the square and the rhombus.
    ax.text(0.45, 0.5, s=r':', va='center', ha='center',
            fontsize=14, fontweight='bold', color='#2d3742')
    ax.axis('off')

In [29]:
def plot_cell_image(ax, aa):
    """Show the cell-label grid of a HexCells object as a coloured image.

    Parameters
    ----------
    ax : matplotlib Axes
    aa : HexCells
        Object whose `image` and `cell_lbs` are read.
    """
    grid = aa.image.copy().astype(int)
    # Empty positions (-1) get their own index one past the largest label.
    grid[grid == -1] = aa.cell_lbs.max()+1
    # Colour lookup table indexed by label.
    palette = to_rgba(['#cee4cc'] + ['C{}'.format(ii) for ii in range(0, 10)]+['white','white'])
    ax.imshow(palette[grid])

    # Minor ticks on the cell borders carry the white grid lines.
    n_rows, n_cols = grid.shape[0], grid.shape[1]
    ax.set_xticks(np.arange(n_cols+1)-.5, minor=True)
    ax.set_yticks(np.arange(n_rows+1)-.5, minor=True)
    ax.tick_params(axis='both', which='minor', length=0)
    ax.grid(which="minor", color="w", linestyle='-', linewidth=0.4)
    ax.set_xticks([])
    ax.set_yticks([])

    # NOTE(review): hard-coded row range -- presumably specific to this data set.
    ax.set_ylim(47.5, -0.5)
    
    
def plot_hierarchy(ax, bb, zoom=0.006):
    """Draw the hierarchy graph and annotate every node with a fraction.

    Reads the notebook globals `nu` (numerators, one per node) and `deno`
    (common denominator).

    Parameters
    ----------
    ax : matplotlib Axes
    bb : graph object providing `plot(ax, zoom=...)` and `pos`
    zoom : float
        Forwarded to `bb.plot`.
    """
    bb.plot(ax, zoom=zoom)
    ax.set_yticks([0,1,2,3])
    ax.set_yticklabels(['3', '4', '5', '6'], fontsize=6)

    fontsize = 8
    for ii, (x, y) in enumerate(bb.pos):
        # Nodes 0-2 are annotated below, node 4 above and to the right,
        # all other nodes directly above.
        if ii in [0, 1, 2]:
            text_x, text_y = x, y-0.2
        elif ii == 4:
            text_x, text_y = x+0.5, y+0.2
        else:
            text_x, text_y = x, y+0.2
        fraction = r'$\frac{{{}}}{{{}}}$'.format(nu[ii], deno)
        ax.text(text_x, text_y, s=fraction, fontsize=fontsize, ha='center')

In [30]:
# Assemble the final two-row figure.
# Top row: (A) cell-label image, (B) hierarchy graph of connected groups.
# NOTE(review): `h_axes` and `get_top_from_axes` are not defined in this
# notebook (presumably `stempy` layout helpers).
fig = plt.figure(figsize=(7.2, 7.2/2))
ax1, ax2 = h_axes(fig, n=2, top=0.9, left=1/15, right=1/30, wspace=0.25, ratios=[1, 1.3], h=0)
plot_cell_image(ax1, aa)
plot_hierarchy(ax2, bb, zoom=0.006)
ax2.set_ylabel('Levels (truncated)')

# Bottom row: one legend entry per box colour C0..C3, each with its own pair
# of dot colours (c1, c2).
top = get_top_from_axes([ax1, ax2], hspace=0.1)
axes2 = h_axes(fig, n=4, top=top, left=1/15, right=1/30, wspace=0.25, ratios=1, h=0.5)
legend_dot_colors = [('C4', 'C1'), ('C2', 'C0'), ('C9', 'C4'), ('C3', 'C2')]
for ii, (ax, (c1, c2)) in enumerate(zip(axes2, legend_dot_colors)):
    plot_block_legend(ax, box_color='C{}'.format(ii), s=8, c1=c1, c2=c2)

# Transparent full-figure axes used only to place the panel letters.
ax = fig.add_axes([0, 0, 1 , 1], fc=[0, 0, 0, 0])
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)

ax.axis('off')

ax.text(x=0.0139, y=0.9532, s='A', ha='left', va='top', transform=ax.transAxes, fontweight='bold')
ax.text(x=0.4806, y=0.9532, s='B', ha='left', va='top', transform=ax.transAxes, fontweight='bold')

Text(0.4806, 0.9532, 'B')

In [207]:
# Write the assembled figure to disk at 500 dpi with a transparent background.
# NOTE(review): this cell's execution count (207) is far ahead of the cells
# above, so it was run out of sequence; re-run the notebook top to bottom
# before relying on the saved file.
fig.savefig('Hierarchy_WS2.png', dpi=500, transparent=True) 