In [1]:
import numpy as np
import matplotlib
matplotlib.rc('figure', dpi=190)
matplotlib.rc('image', cmap='gray')
import matplotlib.pyplot as plt
import cv2
from PIL import Image

In [2]:
def make_dendrogram(img):
    """Build a ``Dendrogram`` describing the intensity basins of a 2-D image.

    The image is shifted so its minimum is 0, scaled so its maximum is 1 (unless
    it is constant), surrounded by a one-pixel border of value 1, and quantized
    to integer levels 0..255. The tree is grown from the first maximum of the
    padded array, which is always a border pixel because the border sits at
    the top level 255.

    Parameters
    ----------
    img : ndarray
        2-D image array.

    Returns
    -------
    Dendrogram
        Tree built over the quantized, unpadded image.
    """
    shifted = img - img.min()
    peak = shifted.max()
    scaled = shifted / peak if peak != 0 else shifted

    n_rows, n_cols = scaled.shape[0], scaled.shape[1]
    padded = np.ones((n_rows + 2, n_cols + 2))
    padded[1:-1, 1:-1] = scaled
    levels = np.round(padded * 255)

    # Border pixels start out visited, so the flood fill never expands from
    # them (their neighbours would fall outside the array).
    visited = np.zeros(levels.shape)
    visited[0, :] = 1
    visited[-1, :] = 1
    visited[:, 0] = 1
    visited[:, -1] = 1

    seed = np.unravel_index(np.argmax(levels), levels.shape)
    root, _frontier = dendrogram(levels[seed], seed, levels, visited)
    return Dendrogram(root, levels[1:-1, 1:-1])
    
def find_min(start, im):
    """Walk downhill from ``start`` to a local minimum of ``im``.

    Performs a search over 8-connected pixels. A pixel equal to the current
    best value is expanded. A strictly lower pixel becomes the new best, and
    the frontier restarts from its neighbourhood. Higher pixels are marked
    seen and dropped. The outermost ring of ``im`` is treated as already
    visited, so ``im`` is expected to carry a one-pixel padding border.

    Parameters
    ----------
    start : tuple of int
        (row, col) index to start from.
    im : 2-D ndarray
        Padded image.

    Returns
    -------
    (tuple, value)
        Index of the minimum found and its value in ``im``.
    """
    seen = np.ones(im.shape)
    seen[1:-1, 1:-1] = 0
    seen[start] = 1

    best_val = im[start]
    best_px = start
    frontier = set(neighbors(*start))
    while frontier:
        row, col = frontier.pop()
        if seen[row, col]:
            continue
        seen[row, col] = 1
        val = im[row, col]
        if val < best_val:
            # Strictly lower: restart the search around the new minimum.
            best_val = val
            best_px = (row, col)
            frontier = set(neighbors(row, col))
        if val == best_val:
            frontier.update(
                nb for nb in neighbors(row, col) if not seen[nb]
            )
    return best_px, best_val
    
def neighbors(r, c):
    """Return the 8-connected neighbours of pixel (r, c).

    The four edge-adjacent neighbours come first (up, down, left, right),
    followed by the four diagonal ones. No bounds checking is done.
    """
    offsets = ((-1, 0), (1, 0), (0, -1), (0, 1),
               (-1, -1), (-1, 1), (1, -1), (1, 1))
    return [(r + dr, c + dc) for dr, dc in offsets]
    
def dendrogram(toplevel, start, im, visited):
    """Recursively build the dendrogram subtree grown from the basin of ``start``.

    Starting from the local minimum returned by ``find_min``, flood-fill upward
    one intensity level at a time until ``toplevel``. When the flood reaches an
    unvisited pixel strictly below the current level, that pixel leads into an
    unexplored basin. The basin is built recursively, capped at ``level - 1``.
    It is then merged with the current branch under a new parent node that
    starts at ``level``.

    Parameters
    ----------
    toplevel : number
        Highest level (inclusive) this call floods up to.
    start : tuple of int
        Padded-image index from which to search for the basin minimum.
    im : 2-D ndarray
        Padded, quantized image.
    visited : 2-D ndarray
        Visited mask in padded coordinates, shared across the recursion and
        mutated in place.

    Returns
    -------
    (Dendro, set)
        The subtree root. Also the frontier: unvisited padded-coordinate pixels
        bordering the flooded region whose values exceed the last level
        processed.
    """
    (r, c), level = find_min(start, im)
    d = Dendro(level)

    visited[r, c] = 1
    # Node regions are recorded in unpadded image coordinates (padded - 1).
    accrete_history = [{(r-1, c-1)}]
    bfs_queue = {*neighbors(r, c)}
    while level <= toplevel:
        failed = set()    # frontier pixels above the current level
        accreted = set()  # pixels added at this level (unpadded coordinates)
        siblings = []     # subtrees of basins discovered at this level
        while bfs_queue:
            r, c = bfs_queue.pop()
            if visited[r, c]:
                continue
            if im[r, c] <= level:
                visited[r, c] = 1
                accreted.add((r-1, c-1))
                if im[r, c] < level:
                    # A lower, unexplored basin: build its subtree first,
                    # capped just below the current level.
                    sibling, surround = dendrogram(level-1, (r, c), im, visited)
                    # full_region is in unpadded coordinates while the queue is
                    # in padded coordinates; shift before pruning so that the
                    # sibling's (already visited) pixels are what gets removed.
                    sibling_px = {(sr + 1, sc + 1) for sr, sc in sibling.full_region}
                    bfs_queue = bfs_queue.difference(sibling_px)
                    # Continue the flood from the sibling's frontier.
                    bfs_queue = bfs_queue.union(surround)
                    siblings.append(sibling)
                for n in neighbors(r, c):
                    if not visited[n]:
                        bfs_queue.add(n)
            else:
                failed.add((r, c))

        if siblings:
            # Basins merged at this level. Close the current branch (its
            # toplevel is this level) and hang it, together with every sibling,
            # under a new parent that starts at this level. This level's
            # accreted pixels go to the new parent (appended below).
            next_d = Dendro(level)
            d.initialize(accrete_history, level)
            accrete_history = []
            d.parent = next_d
            next_d.children.append(d)
            for s in siblings:
                s.parent = next_d
                next_d.children.append(s)
            d = next_d

        accrete_history.append(accreted)
        # Pixels too high for this level are retried at the next one.
        bfs_queue = failed
        level += 1

    d.initialize(accrete_history, toplevel)
    return d, bfs_queue

In [1]:
class Dendrogram:
    """Flattened view of a ``Dendro`` tree.

    On construction it does the following:

    * If ``im`` is given, computes per-node masses and centroids from it.
    * Lists every node in pre-order (``branches``).
    * Stores each node's position in that list as its ``tree_id``.
    * Builds ``subtrees``, a ``size x size`` 0/1 matrix in which
      ``subtrees[i, j] == 1`` iff branch ``j`` lies in the subtree rooted at
      branch ``i`` (``i`` included).
    """

    # Default palette indexed by node height (leaves use the first colour);
    # heights past the end reuse the last colour.
    PALETTE = ((255, 231, 76),
               (83, 221, 108),
               (99, 160, 136),
               (86, 99, 138),
               (72, 58, 88),
               (86, 32, 61),
               (80, 10, 12),
               (0, 0, 0))

    def __init__(self, root, im):
        self.root = root
        if im is not None:
            root.calculate(im)
        self.branches = root.list_descendants()
        self.size = len(self.branches)
        for tid, branch in enumerate(self.branches):
            branch.tree_id = tid
        self.subtrees = np.zeros((self.size, self.size))
        for tid, branch in enumerate(self.branches):
            for descendant in branch.list_descendants():
                self.subtrees[tid, descendant.tree_id] = 1

    def draw(self, img, colors=None):
        """Paint every node's region onto ``img`` coloured by node height.

        Parameters
        ----------
        img : object with ``putpixel((x, y), value)``
            Canvas to draw on (e.g. a PIL image).
        colors : sequence, optional
            Palette indexed by node height; defaults to ``PALETTE``. Must be
            non-empty.
        """
        # Dendro.color stores its result under the same attribute name, which
        # shadows the method on the instance after the first call; calling it
        # through the class keeps repeated draws working.
        Dendro.color(self.root)
        self.root.draw(img, self.PALETTE if colors is None else colors)

class Dendro:
    """One node of the dendrogram.

    A node owns the pixels it accreted (``region``, unpadded (row, col)
    tuples) between its starting ``level`` and the ``toplevel`` at which it
    was closed. It also owns the child nodes that merged into it.
    """

    def __init__(self, level):
        self.level = level
        self.children = []
        self.parent = None
        self.history = None
        self.region = None
        self.vol = None
        self.mass = None
        self.loc = None
        self.tree_id = None
        self.traj_id = None
        # Filled in by initialize() / calculate().
        self.toplevel = None
        self.full_region = None
        self.exclusive_mass = None
        self.mass_frac = None
        self.weight = None
        # NOTE: no ``self.color`` here -- it would shadow the color() method.

    def initialize(self, accrete_hist, toplevel):
        """Freeze this node's pixels once its flood has finished.

        Children must already be initialized, because their ``full_region`` is
        merged into this node's.

        Parameters
        ----------
        accrete_hist : list of set
            Pixels accreted at each level, in order.
        toplevel : number
            Level at which this node was closed.
        """
        self.history = accrete_hist
        self.toplevel = toplevel
        self.region = set().union(*self.history)
        self.full_region = self.region.union(*[c.full_region for c in self.children])
        self.vol = len(self.full_region)

    def calculate(self, im):
        """Compute mass, mass-weighted centroid and weight for this subtree.

        Each own pixel contributes ``toplevel - im[px]``. Each pixel of a
        child's full region additionally contributes this node's depth
        ``toplevel - level``, on top of the child's own mass. Centroids are
        expressed in image-centred coordinates with y pointing up.

        Parameters
        ----------
        im : 2-D ndarray
            Unpadded quantized image that the region coordinates index into.
        """
        def px_to_coord(px):
            # (row, col) -> (x, y) relative to the image centre, y upward.
            x = px[1] - im.shape[1]//2
            y = im.shape[0]//2 - px[0]
            return np.array([x, y])

        for child in self.children:
            child.calculate(im)

        subtree_mass = 0
        mass = 0
        loc = np.array([0., 0.])
        child_depth = self.toplevel - self.level
        for child in self.children:
            subtree_mass += child.mass
            loc += child.loc * child.mass
            child_region = [px_to_coord(px) for px in child.full_region]
            mass += len(child_region) * child_depth
            loc += np.sum(child_region, axis=0) * child_depth
        for px in self.region:
            depth = self.toplevel - im[px]
            mass += depth
            loc += depth * px_to_coord(px)
        self.exclusive_mass = mass
        self.mass = self.exclusive_mass + subtree_mass
        # NOTE(review): a zero total mass yields nan (numpy) or ZeroDivisionError
        # (plain int) here.
        loc /= self.mass
        self.loc = loc
        self.mass_frac = self.exclusive_mass / self.mass
        sigmoid = lambda x: 1. / (1. + np.exp((0.5-x)/0.1))
        self.weight = sigmoid(self.mass_frac)

    def list_descendants(self):
        """Return this node followed by all descendants, in pre-order."""
        descendants = [self]
        for child in self.children:
            descendants.extend(child.list_descendants())
        return descendants

    def color(self):
        """Assign and return this node's height (0 for leaves, else 1 + max child height).

        The result is stored as ``self.color``, which shadows this method on the
        instance. Recursion therefore goes through ``Dendro.color``, so the tree
        can be re-coloured any number of times.
        """
        if self.children:
            height = 1 + max(Dendro.color(child) for child in self.children)
        else:
            height = 0
        self.color = height
        return height

    def draw(self, img, colors):
        """Paint this node's own region, then its children's, via ``img.putpixel``.

        ``colors`` is indexed by node height; heights past the end of the
        palette use its last colour. Requires ``color()`` to have been run.
        """
        col = colors[min(self.color, len(colors) - 1)]
        for row, column in self.region:
            img.putpixel((column, row), col)
        for child in self.children:
            child.draw(img, colors)