In [None]:
%cd /content/
!rm -rf outflow-tracking
!git clone https://github.com/dkarkada/outflow-tracking
%cd outflow-tracking

import matplotlib.pyplot as plt
import matplotlib._color_data as mcd
import matplotlib.colors as mplc

plt.rc('figure', dpi=100)

from jax.nn import softmax
from jax import grad, jit
import jax.numpy as jnp
from jax.experimental import optimizers

import numpy as np
import cv2
import pickle
   
from IPython.display import display
import pandas as pd

RECALCULATE = True

/content
Cloning into 'outflow-tracking'...
remote: Enumerating objects: 333, done.[K
remote: Counting objects: 100% (50/50), done.[K
remote: Compressing objects: 100% (25/25), done.[K
remote: Total 333 (delta 27), reused 47 (delta 25), pack-reused 283[K
Receiving objects: 100% (333/333), 177.63 MiB | 25.90 MiB/s, done.
Resolving deltas: 100% (77/77), done.
/content/outflow-tracking


# Dendrogram and Annealer

Custom dendrogram implementation for continuous optimization.

The simulated annealer comes from https://github.com/perrygeo/simanneal. That module defines the `Annealer` class, which implements the annealing algorithm. To use it, subclass `Annealer` and implement the abstract methods `move` (which generates a new state) and `energy` (which computes the energy associated with the current state).

In [None]:
class Dendrogram:
    """
    Tree of nested basins ("branches") extracted from a 2D grayscale image.

    The constructor rescales the image to integer levels 0..255, builds the
    tree with `make_recurse`, simplifies it with `Branch.merge`, computes
    per-branch properties with `Branch.calculate`, and then derives:
        branches  - flat list of all branches (root.descendants); a branch's
                    `id` is its index in this list
        N         - number of branches
        metric    - (N, N) matrix of tree distances, see `compute_metric`
        hierarchy - (N, N) matrix with hierarchy[a, b] = 1 when branch b is
                    in branch a's `descendants` (which includes a itself)
        masses    - length N+1 array of branch masses with one extra entry
                    (the largest branch mass) appended for the source
        x         - (N, 2) array of branch centers of mass
    """

    def make_recurse(self, start, top, visited):
        """
        Recursive function for constructing dendrogram from a square, 2D image.
        Params:
            start ((int, int)): index of starting point to find nearest local
                min
            top: (int): the elevation of the top of the current basin
            visited (boolean ndarray): same size as img. Tracks which pixels
                have been visited. Shared with, and mutated by, every
                recursive call.
        Returns:
            (Branch, set): the topmost branch built for this basin, and the
                set of pixel indices left on the frontier when the flood
                stopped (pixels that were not absorbed below `top`).
        """

        # `im` is a free variable here; it is bound further down in this
        # function (im = self.img) before `within` is first called.
        within = lambda r, c: (0 <= r < im.shape[0]) and (0 <= c < im.shape[1])

        def neighbors(r, c):
            # 8-connected neighborhood; may yield out-of-bounds indices,
            # callers filter with `within`.
            return [(r-1, c), (r+1, c), (r, c-1), (r, c+1),
                    (r-1, c-1), (r-1, c+1), (r+1, c-1), (r+1, c+1)]
        
        def find_min(start):
            # Descend from `start` to a local minimum. Uses its own scratch
            # `visited` map, so it does not mark pixels in the caller's map.
            im = self.img
            visited = np.zeros(im.shape)
            cur = im[start]
            min_ind = start
            bfs_queue = {start}
            while bfs_queue:
                r, c = bfs_queue.pop()
                if not within(r, c) or visited[r, c]:
                    continue
                visited[r, c] = 1
                if im[r, c] < cur:
                    # Strictly lower pixel: move there and restart the
                    # frontier from its neighbors, dropping the old one.
                    cur = im[r, c]
                    min_ind = (r, c)
                    bfs_queue = {*neighbors(r, c)}
                if im[r, c] == cur:
                    # Same level as the current minimum (also true right
                    # after the update above): keep exploring the plateau.
                    for n in neighbors(r, c):
                        if within(*n) and not visited[n]:
                            bfs_queue.add(n)
            return min_ind, cur
        
        im = self.img
        (r, c), bottom = find_min(start)
        level = bottom
        d = Branch(self)
        
        # NOTE: `map` shadows the builtin for the rest of this function.
        map = d.map
        bfs_queue = {(r, c)}
        # Flood upward one integer level at a time until reaching `top`.
        while level < top:
            failed = set()      # frontier pixels higher than the current level
            siblings = []       # neighboring basins discovered at this level
            while bfs_queue:
                r, c = bfs_queue.pop()
                if not within(r, c) or visited[r, c]:
                    continue
                if im[r, c] == level:
                    # Pixel sits exactly at the current level: claim it for
                    # this branch and expand to its unvisited neighbors.
                    visited[r, c] = 1
                    map[r, c] = 1
                    for n in neighbors(r, c):
                        if within(*n) and not visited[n]:
                            bfs_queue.add(n)
                elif im[r, c] < level:
                    # Pixel lies below the current level: it belongs to a
                    # different basin. Build that basin up to `level` and
                    # continue this flood from the frontier it returns.
                    sibling, surround = self.make_recurse((r, c),
                                                          level, visited)
                    bfs_queue = bfs_queue.union(surround)
                    siblings.append(sibling)
                else:
                    # Too high for now; retry at the next level.
                    failed.add((r, c))
            
            if siblings:
                # Basins met at this level: create a parent branch whose
                # children are the current branch and all its siblings.
                next_d = Branch(self)
                siblings.append(d)
                for s in siblings:
                    s.parent = next_d
                    next_d.children.append(s)
                    next_d.map[s.map>0] = 1
                # Convert the finished branch's mask into values measured
                # down from the merge level, clipped at the branch bottom.
                map[map>0] = level - np.maximum(im[map>0], bottom)
                d = next_d
                map = d.map
                bottom = level
            
            level += 1
            bfs_queue = failed
        
        # The outermost call uses top == 256 (one above the maximum pixel
        # value of 255), hence the correction of 1 in that case.
        root_correction = 1 if top==256 else 0
        map[map>0] = top - np.maximum(im[map>0], bottom) - root_correction
        return d, bfs_queue
    
    def __init__(self, img):
        """
        Construct a dendrogram from a square, 2D image.
        Params:
            img (ndarray): The grayscale (1-channel) image
        """
        # Rescale to integer-valued levels in [0, 255] (stored as floats).
        img = img - img.min()
        if img.max() != 0:
            img = img / img.max()
        img = np.round(img * 255)
        self.img = img
        
        # Pixel coordinate grids: coords[0] runs from -w/2 upward along
        # columns, coords[1] runs from h/2 downward along rows.
        h, w = self.img.shape
        coords = np.meshgrid(np.linspace(-w/2, w/2, w, endpoint=False),
                             np.linspace(h/2, -h/2, h, endpoint=False))
        self.coords = np.stack(coords)

        # Start the recursion from the brightest pixel with the ceiling one
        # level above the maximum possible value.
        visited = np.zeros(img.shape)
        ind = np.unravel_index(np.argmax(img), img.shape)
        root, _ = self.make_recurse(ind, 256, visited)
        self.root = root

        root.merge()
        root.calculate()
        self.branches = root.descendants
        self.N = len(self.branches)
        # Branch ids are positions in self.branches.
        for n, b in enumerate(self.branches):
            b.id = n
        
        self.metric = np.zeros((self.N, self.N))
        self.compute_metric(root)
        
        self.hierarchy = np.zeros((self.N, self.N))
        for branch_id in range(self.N):
            twigs = self.branches[branch_id].descendants
            for twig_id in [t.id for t in twigs]:
                self.hierarchy[branch_id, twig_id] = 1
        
        mass_arr = [b.mass for b in self.branches]
        # append source mass
        self.masses = np.array(mass_arr + [max(mass_arr)])
        self.x = np.array([b.x for b in self.branches])
        
    def compute_metric(self, branch):
        """
        Recursively compute tree metric for a dendrogram. Stores
        as an nxn matrix in self.metric, n = num branches
        Params:
            branch (Branch): The current branch
        """
        if len(branch.children) == 0:
            return
        
        subtrees = []
        for c in branch.children:
            self.compute_metric(c)
            subtree = c.descendants
            # Distance from `branch` to anything under child c is one edge
            # more than its distance from c (c itself is in c.descendants
            # with distance 0, so it ends up at distance 1).
            for b in subtree:
                dist = 1 + self.metric[c.id, b.id]
                # each pair must be filled exactly once
                assert self.metric[branch.id, b.id] == 0
                self.metric[branch.id, b.id] = dist
                self.metric[b.id, branch.id] = dist
            subtrees.append((c, subtree))
            
        # Pairs living under two different children are connected through
        # `branch`: distance up to c1, two edges across, distance down from c2.
        for c1_ind in range(len(subtrees)):
            for c2_ind in range(c1_ind+1, len(subtrees)):
                c1, subtree1 = subtrees[c1_ind]
                c2, subtree2 = subtrees[c2_ind]
                for b1 in subtree1:
                    for b2 in subtree2:
                        dist = 2 + self.metric[c1.id, b1.id] \
                                 + self.metric[c2.id, b2.id]
                        assert self.metric[b1.id, b2.id] == 0
                        self.metric[b1.id, b2.id] = dist
                        self.metric[b2.id, b1.id] = dist

    def print(self):
        """
        Print the branch ids as an indented tree, two spaces per depth level.
        Children are pushed on a stack, so siblings are printed in the
        reverse of their order in `children`.
        """
        dfs_stack = [(self.root, 0)]
        while len(dfs_stack) > 0:
            b, level = dfs_stack.pop()
            print("{}{}".format(level*"  ", b.id))
            for c in b.children:
                dfs_stack.append((c, level+1))
            

class Branch:
    """
    One node of a Dendrogram.

    `map` is an image-sized array holding this branch's per-pixel values
    (zero outside the branch). `children`, `parent` and `descendants`
    describe the tree; `descendants` always starts with the branch itself.
    `mass`, `total_mass`, `mass_frac`, `x` and `covariance` stay None until
    `calculate` is called, and `id` stays None until the owning Dendrogram
    assigns it.
    """

    def __init__(self, dendrogram):
        # tree structure
        self.dendrogram = dendrogram
        self.parent = None
        self.children = []
        self.descendants = [self]
        self.id = None

        # per-pixel content and the statistics derived from it
        self.map = np.zeros(dendrogram.img.shape)
        self.mass = None
        self.total_mass = None
        self.mass_frac = None
        self.x = None
        self.covariance = None

    def merge(self, depth_thresh=8):
        """
        Simplify the sub-tree rooted here by folding shallow children into
        their parent.
        Params:
            depth_thresh (int): children whose largest map value is below
                this threshold are absorbed into this branch

        Extends `self.descendants` in place, so it is meant to be called
        once per tree.
        """
        def absorb(child):
            # Add the child's pixels to this branch and adopt its children.
            self.map += child.map
            adopted = list(child.children)
            for grandchild in adopted:
                grandchild.parent = self
            return adopted

        kept = []
        for child in self.children:
            # simplify bottom-up so grandchildren are settled first
            child.merge(depth_thresh)
            if np.max(child.map) < depth_thresh:
                kept.extend(absorb(child))
            else:
                kept.append(child)
        self.children = kept

        # A branch with a single child is not a real split: collapse it.
        if len(self.children) == 1:
            (only_child,) = self.children
            self.children = absorb(only_child)

        for child in self.children:
            self.descendants.extend(child.descendants)

    def calculate(self):
        """
        Compute mass, total mass, mass fraction, center of mass and spatial
        covariance matrix for every branch in this sub-tree (children first).
        """
        for child in self.children:
            child.calculate()

        grid = self.dendrogram.coords
        weights = self.map

        # COMPLETE HACK (kept as-is): floor the mass at 1 so the divisions
        # below never see a zero.
        self.mass = max(1, weights.sum())
        # NOTE(review): this adds only the direct children's own `mass`,
        # not their `total_mass`, so deeper descendants are not counted --
        # confirm whether that is intended.
        self.total_mass = self.mass + sum([child.mass for child in self.children])
        self.mass_frac = self.mass / self.total_mass

        # center of mass: map-weighted mean of the coordinate grids
        self.x = np.einsum('ijk,jk->i', grid, weights) / self.mass

        # covariance: map-weighted mean of the outer products of offsets
        offsets = grid - self.x[:, None, None]
        outer_products = np.einsum('ijk,ljk->iljk', offsets, offsets)
        weighted_sum = np.einsum('iljk,jk->il', outer_products, weights)
        self.covariance = weighted_sum / self.mass

In [None]:
def plot_gaussian_estimate(dendro):
    """
    Given a dendrogram, display the branches in real space as Gaussian
    covariance ellipses.

    The image is drawn in grayscale and one ellipse is overlaid per branch,
    centered on the branch's center of mass. Ellipse opacity is the cube of
    the branch's `mass_frac`.
    Params:
        dendro (Dendrogram): the dendrogram to draw
    """
    def get_ellipse(cov):
        # Closed-form eigen-decomposition of the symmetric 2x2 matrix
        # [[a, b], [b, c]]: eigenvalues are c1 +/- c2.
        a, b, c = cov[0,0], cov[0,1], cov[1,1]
        c1 = (a+c)/2
        c2 = np.linalg.norm([b, (a-c)/2])
        l1 = c1 + c2
        # For a singular covariance (e.g. a branch whose pixels lie on one
        # line) c1 and c2 are equal in exact arithmetic, so round-off can
        # push the difference slightly below zero. Clamp it instead of
        # failing an assertion and taking the sqrt of a negative number.
        l2 = max(c1 - c2, 0.0)
        assert l1 >= 0
        # orientation of the major axis
        theta = 0
        if b == 0 and a < c:
            theta = np.pi/2
        if b != 0:
            theta = np.arctan2(l1-a, b)
        t = np.linspace(0, 2*np.pi, 30)
        # Rotate by -theta because the image's row axis points down while
        # the branch coordinates have y pointing up.
        x = np.sqrt(6*l1)*np.cos(-theta)*np.cos(t) - np.sqrt(6*l2)*np.sin(-theta)*np.sin(t)
        y = np.sqrt(6*l1)*np.sin(-theta)*np.cos(t) + np.sqrt(6*l2)*np.cos(-theta)*np.sin(t)
        return x, y

    plt.imshow(dendro.img, cmap=plt.cm.Greys_r)
    plt.colorbar()
    for b in dendro.branches:
        x, y = get_ellipse(b.covariance)
        # convert centered (x right, y up) coordinates to pixel coordinates
        mu_x = b.x[0] + b.map.shape[1]//2
        mu_y = b.map.shape[0]//2 - b.x[1]
        x, y = x + mu_x, y + mu_y
        plt.plot(x, y, color=(.7,.3,.3,b.mass_frac**3))
    plt.show()

In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import abc
import copy
import datetime
import math
import pickle
import random
import signal
import sys
import time


def round_figures(x, n):
    """Returns x rounded to n significant figures.

    Zero has no order of magnitude (log10(0) is undefined), so it is
    returned unchanged instead of raising ValueError.
    """
    if x == 0:
        return x
    return round(x, int(n - math.ceil(math.log10(abs(x)))))


def time_string(seconds):
    """Format a duration given in seconds as 'HHHH:MM:SS'.

    The input is rounded to the nearest whole second first; the hours
    field is right-aligned in four characters.
    """
    total_seconds = int(round(seconds))
    total_minutes, secs = divmod(total_seconds, 60)
    hours, minutes = divmod(total_minutes, 60)
    return '%4i:%02i:%02i' % (hours, minutes, secs)


class Annealer(object, metaclass=abc.ABCMeta):

    """Performs simulated annealing by calling functions to calculate
    energy and make moves on a state.  The temperature schedule for
    annealing may be provided manually or estimated automatically.

    Subclasses must implement `move` and `energy`. The class is declared
    with `metaclass=abc.ABCMeta` (the Python 3 spelling; a `__metaclass__`
    attribute is ignored on Python 3), so instantiating a class that lacks
    either method raises TypeError.

    Note: constructing an Annealer installs `set_user_exit` as the process
    SIGINT handler, so Ctrl-C / kernel interrupt stops the running anneal
    gracefully instead of raising KeyboardInterrupt.
    """

    # defaults
    Tmax = 25000.0
    Tmin = 2.5
    steps = 50000
    updates = 100
    copy_strategy = 'deepcopy'
    user_exit = False
    save_state_on_exit = False

    # placeholders
    best_state = None
    best_energy = None
    start = None

    def __init__(self, initial_state=None, load_state=None):
        """Start from `initial_state` (copied), or from the pickle file
        named by `load_state`. Raises ValueError when neither is given."""
        if initial_state is not None:
            self.state = self.copy_state(initial_state)
        elif load_state:
            self.load_state(load_state)
        else:
            raise ValueError('No valid values supplied for either '
                             'initial_state or load_state')

        signal.signal(signal.SIGINT, self.set_user_exit)

    def save_state(self, fname=None):
        """Saves state to pickle. Without `fname`, a name is built from the
        current timestamp and the current energy."""
        if not fname:
            date = datetime.datetime.now().strftime("%Y-%m-%dT%Hh%Mm%Ss")
            fname = date + "_energy_" + str(self.energy()) + ".state"
        with open(fname, "wb") as fh:
            pickle.dump(self.state, fh)

    def load_state(self, fname=None):
        """Loads state from pickle.

        NOTE: unpickling executes arbitrary code; only load files you
        created yourself.
        """
        with open(fname, 'rb') as fh:
            self.state = pickle.load(fh)

    @abc.abstractmethod
    def move(self):
        """Create a state change.

        May return the resulting energy difference; when it returns None
        the caller recomputes the energy with `energy()`.
        """
        pass

    @abc.abstractmethod
    def energy(self):
        """Calculate state's energy"""
        pass

    def set_user_exit(self, signum, frame):
        """Raises the user_exit flag, further iterations are stopped
        """
        self.user_exit = True

    def set_schedule(self, schedule):
        """Takes the output from `auto` and sets the attributes
        """
        self.Tmax = schedule['tmax']
        self.Tmin = schedule['tmin']
        self.steps = int(schedule['steps'])
        self.updates = int(schedule['updates'])

    def copy_state(self, state):
        """Returns an exact copy of the provided state
        Implemented according to self.copy_strategy, one of
        * deepcopy: use copy.deepcopy (slow but reliable)
        * slice: use list slices (faster but only works if state is list-like)
        * method: use the state's copy() method
        """
        if self.copy_strategy == 'deepcopy':
            return copy.deepcopy(state)
        elif self.copy_strategy == 'slice':
            return state[:]
        elif self.copy_strategy == 'method':
            return state.copy()
        else:
            raise RuntimeError('No implementation found for ' +
                               'the self.copy_strategy "%s"' %
                               self.copy_strategy)

    def update(self, *args, **kwargs):
        """Wrapper for internal update.
        If you override the self.update method,
        you can choose to call the self.default_update method
        from your own Annealer.
        """
        self.default_update(*args, **kwargs)

    def default_update(self, step, T, E, acceptance, improvement):
        """Default update, outputs to stderr.
        Prints the current temperature, energy, acceptance rate,
        improvement rate, elapsed time, and remaining time.
        The acceptance rate indicates the percentage of moves since the last
        update that were accepted by the Metropolis algorithm.  It includes
        moves that decreased the energy, moves that left the energy
        unchanged, and moves that increased the energy yet were reached by
        thermal excitation.
        The improvement rate indicates the percentage of moves since the
        last update that strictly decreased the energy.  At high
        temperatures it will include both moves that improved the overall
        state and moves that simply undid previously accepted moves that
        increased the energy by thermal excitation.  At low temperatures
        it will tend toward zero as the moves that can decrease the energy
        are exhausted and moves that would increase the energy are no longer
        thermally accessible."""

        elapsed = time.time() - self.start
        if step == 0:
            # header plus the first row (no rates available yet)
            print('\n Temperature        Energy    Accept   Improve     Elapsed   Remaining',
                  file=sys.stderr)
            print('\r{Temp:12.5f}  {Energy:12.2f}                      {Elapsed:s}            '
                  .format(Temp=T,
                          Energy=E,
                          Elapsed=time_string(elapsed)),
                  file=sys.stderr, end="")
            sys.stderr.flush()
        else:
            # linear extrapolation of the time per step so far
            remain = (self.steps - step) * (elapsed / step)
            print('\r{Temp:12.5f}  {Energy:12.2f}   {Accept:7.2%}   {Improve:7.2%}  {Elapsed:s}  {Remaining:s}'
                  .format(Temp=T,
                          Energy=E,
                          Accept=acceptance,
                          Improve=improvement,
                          Elapsed=time_string(elapsed),
                          Remaining=time_string(remain)),
                  file=sys.stderr, end="")
            sys.stderr.flush()

    def anneal(self):
        """Minimizes the energy of a system by simulated annealing.
        Parameters
        state : an initial arrangement of the system
        Returns
        (state, energy): the best state and energy found.
        Raises
        Exception: if Tmin is not strictly positive.
        """
        step = 0
        self.start = time.time()

        # Precompute factor for exponential cooling from Tmax to Tmin
        if self.Tmin <= 0.0:
            raise Exception('Exponential cooling requires a minimum '
                            'temperature greater than zero.')
        Tfactor = -math.log(self.Tmax / self.Tmin)

        # Note initial state
        T = self.Tmax
        E = self.energy()
        prevState = self.copy_state(self.state)
        prevEnergy = E
        self.best_state = self.copy_state(self.state)
        self.best_energy = E
        trials = accepts = improves = 0
        if self.updates > 0:
            updateWavelength = self.steps / self.updates
            self.update(step, T, E, None, None)

        # Attempt moves to new states
        while step < self.steps and not self.user_exit:
            step += 1
            T = self.Tmax * math.exp(Tfactor * step / self.steps)
            dE = self.move()
            if dE is None:
                E = self.energy()
                dE = E - prevEnergy
            else:
                E += dE
            trials += 1
            # Metropolis criterion: uphill moves survive with
            # probability exp(-dE / T)
            if dE > 0.0 and math.exp(-dE / T) < random.random():
                # Restore previous state
                self.state = self.copy_state(prevState)
                E = prevEnergy
            else:
                # Accept new state and compare to best state
                accepts += 1
                if dE < 0.0:
                    improves += 1
                prevState = self.copy_state(self.state)
                prevEnergy = E
                if E < self.best_energy:
                    self.best_state = self.copy_state(self.state)
                    self.best_energy = E
            if self.updates > 1:
                # report whenever the step count crosses an update boundary
                if (step // updateWavelength) > ((step - 1) // updateWavelength):
                    self.update(
                        step, T, E, accepts / trials, improves / trials)
                    trials = accepts = improves = 0

        self.state = self.copy_state(self.best_state)
        if self.save_state_on_exit:
            self.save_state()

        # Return best state and energy
        return self.best_state, self.best_energy

    def auto(self, minutes, steps=2000):
        """Explores the annealing landscape and
        estimates optimal temperature settings.
        Returns a dictionary suitable for the `set_schedule` method.
        """

        def run(T, steps):
            """Anneals a system at constant temperature and returns the
            energy, rate of acceptance, and rate of improvement."""
            E = self.energy()
            prevState = self.copy_state(self.state)
            prevEnergy = E
            accepts, improves = 0, 0
            for _ in range(steps):
                dE = self.move()
                if dE is None:
                    E = self.energy()
                    dE = E - prevEnergy
                else:
                    E = prevEnergy + dE
                if dE > 0.0 and math.exp(-dE / T) < random.random():
                    self.state = self.copy_state(prevState)
                    E = prevEnergy
                else:
                    accepts += 1
                    if dE < 0.0:
                        improves += 1
                    prevState = self.copy_state(self.state)
                    prevEnergy = E
            return E, float(accepts) / steps, float(improves) / steps

        step = 0
        self.start = time.time()

        # Attempting automatic simulated anneal...
        # Find an initial guess for temperature
        T = 0.0
        E = self.energy()
        self.update(step, T, E, None, None)
        while T == 0.0:
            step += 1
            dE = self.move()
            if dE is None:
                dE = self.energy() - E
            T = abs(dE)

        # Search for Tmax - a temperature that gives 98% acceptance
        E, acceptance, improvement = run(T, steps)

        step += steps
        while acceptance > 0.98:
            T = round_figures(T / 1.5, 2)
            E, acceptance, improvement = run(T, steps)
            step += steps
            self.update(step, T, E, acceptance, improvement)
        while acceptance < 0.98:
            T = round_figures(T * 1.5, 2)
            E, acceptance, improvement = run(T, steps)
            step += steps
            self.update(step, T, E, acceptance, improvement)
        Tmax = T

        # Search for Tmin - a temperature that gives 0% improvement
        while improvement > 0.0:
            T = round_figures(T / 1.5, 2)
            E, acceptance, improvement = run(T, steps)
            step += steps
            self.update(step, T, E, acceptance, improvement)
        Tmin = T

        # Calculate anneal duration
        elapsed = time.time() - self.start
        duration = round_figures(int(60.0 * minutes * step / elapsed), 2)

        # Don't perform anneal, just return params
        return {'tmax': Tmax, 'tmin': Tmin, 'steps': duration, 'updates': self.updates}

In [None]:
from collections import namedtuple
BranchData = namedtuple('BranchData', 'id mass total_mass x')

class SimpleNode:
    """
    Slim stand-in for a dendrogram branch, used as the annealing state.

    The annealer deep-copies its state on every step by default, so the
    full Branch objects (which carry image-sized maps) would be slow to
    copy. A node keeps only a `BranchData` record, its tree links, and the
    matching bookkeeping: `link` is the node in the other tree this node is
    matched to, and `backlinks` is the set of nodes matched to this one.
    """

    def __init__(self, branch, children):
        self.data = BranchData(id=branch.id,
                               mass=branch.mass,
                               total_mass=branch.total_mass,
                               x=branch.x)
        self.parent = None
        self.children = children
        for kid in children:
            kid.parent = self

        # matching state
        self.link = None
        self.backlinks = set()

        # filled in by calc_desc
        self.level = None
        self.descendants = None
        self.desc_prob = None

    def calc_desc(self, level=0):
        """
        Record this node's depth, collect its sub-tree (itself first, then
        each child's sub-tree in order) and compute the probability of
        picking each sub-tree member as a match target. Returns the list
        of descendants.
        """
        self.level = level
        subtree = [self]
        for kid in self.children:
            subtree.extend(kid.calc_desc(level + 1))
        self.descendants = subtree

        depth_below = np.array([node.level - self.level for node in subtree])
        # base 3 because it should be >2 since tree is ~binary and
        # we want to weight the first level greater
        weights = 3.0**(-np.abs(depth_below - 1))
        self.desc_prob = weights / np.sum(weights)
        return subtree

    def switch(self, rng):
        """
        Re-draw the match of every node below this one: each child is
        linked to a random member of this node's link's sub-tree (weighted
        by `desc_prob`), then the child's own children are re-drawn
        recursively. Requires this node to be linked already.
        """
        target = self.link
        if target is None:
            print("Must link first.")
            return

        weights = target.desc_prob
        candidates = target.descendants
        for kid in self.children:
            previous = kid.link
            if previous is not None:
                previous.backlinks.remove(kid)
            kid.link = rng.choice(candidates, p=weights)
            kid.link.backlinks.add(kid)
            kid.switch(rng)

    def print(self):
        """Print 'id-linked_id' for this sub-tree, indented by depth."""
        if self.link is None:
            print("Can't print this tree.")
            return
        indent = self.level*" "
        print("{}{}-{}".format(indent, self.data.id, self.link.data.id))
        for kid in self.children:
            kid.print()


def branch_to_tree(b):
    """Recursively mirror the sub-tree rooted at branch `b` as SimpleNodes
    (children are converted first, in order) and return the new root."""
    return SimpleNode(b, [branch_to_tree(child) for child in b.children])


class DendroMatch(Annealer):
    """
    Annealer whose state is the root SimpleNode of one tree, linked to the
    root of a second tree. Moves re-draw links; the energy scores how well
    linked nodes agree in mass and position.
    """

    # One generator shared by all instances (unseeded, so runs differ).
    rng = np.random.default_rng()

    def move(self):
        """Pick a random node of the state tree and re-draw the links of
        everything below it. Returns None, so the annealer recomputes the
        energy from scratch."""
        generator = DendroMatch.rng
        chosen = generator.choice(self.state.descendants)
        chosen.switch(generator)

    def energy(self):
        """Sum a penalty over every node of the linked (target) tree."""
        target_root = self.state.link
        total = 0
        for target in target_root.descendants:
            sources = target.backlinks
            if len(sources) == 0:
                # nothing links here: pay the node's own mass
                total += target.data.mass
                continue

            m_j = target.data.total_mass
            x_j = target.data.x

            m_in = np.array([src.data.total_mass for src in sources])
            x_in = np.array([src.data.x for src in sources])

            # combined incoming mass and its mass-weighted mean position
            m_i = m_in.sum()
            x_i = np.einsum('i,ij->j', m_in/m_i, x_in)

            # the mass term is zero when m_i == m_j and symmetric in the ratio
            cost_m = (m_i/m_j + m_j/m_i - 2)**2
            cost_x = np.linalg.norm(x_i - x_j)**2
            total += 100*cost_m + cost_x

        return total

# Construct dendrograms and initialize using annealing

In [None]:
"""
Read the images as numpy arrays and construct dendrograms from each image.
"""

# Number of simulation snapshots to process; snapshot k is read from
# data/sim/z_tracer_{50 + 10*k}.npy.
N_FRAMES = 20

if RECALCULATE:
    frames = []
    max_val = 0
    for frame_num in range(N_FRAMES):
        im = np.load('data/sim/z_tracer_{}.npy'.format(50+10*frame_num))
        im = cv2.resize(im, (300, 300))
        # The third positional argument of GaussianBlur is sigmaX, not the
        # border type. The original call passed cv2.BORDER_CONSTANT (== 0)
        # there, which meant "sigmaX = 0" (sigma derived from the kernel
        # size) with the default border handling. Spell that out so the
        # call says what it does; results are unchanged.
        # NOTE(review): if a constant border was really intended, pass
        # borderType=cv2.BORDER_CONSTANT as a keyword instead.
        im = cv2.GaussianBlur(im, (3, 3), sigmaX=0)
        
        # # convert mass density to log-scale
        # im[im < 1e-3] = 1e-3
        # im = np.log10(im / 1e-3)
        
        frames.append(im)
        max_val = max(im.max(), max_val)
        
    dendro_frames = []
    for num, im in enumerate(frames):
        print('.', end='')
        # normalize all frames by the global maximum (in place)
        im /= max_val
        # negate so that bright regions become the basins of the dendrogram
        d = Dendrogram(-im)
        dendro_frames.append(d)

        # cache the dendrogram so later runs can set RECALCULATE = False
        with open('data/dendros/{:02d}.pickle'.format(num), 'wb') as handle:
            pickle.dump(d, handle)

else:
    # Load the cached dendrograms written by a previous RECALCULATE run.
    dendro_frames = []
    for frame_num in range(N_FRAMES): 
        with open('data/dendros/{:02d}.pickle'.format(frame_num), 'rb') as handle:
            dendro_frames.append(pickle.load(handle)) 

....................

In [None]:
"""
Construct the annealing links and run the annealing algorithm to
initialize linking matrices.
"""

if not RECALCULATE:
    # Reuse the annealing results cached by a previous run.
    with open('data/initialization.pickle', 'rb') as handle:
        anneal_states = pickle.load(handle)

else:
    # Same temperature schedule for every pair of consecutive frames.
    schedule = {
        "tmax":5000,
        "tmin":2,
        "steps":50000,
        "updates":200,
    }

    anneal_states = []
    for i in range(1, len(dendro_frames)):
        # lightweight copies of the two consecutive dendrograms
        t0 = branch_to_tree(dendro_frames[i-1].root)
        t1 = branch_to_tree(dendro_frames[i].root)
        t0.calc_desc()
        t1.calc_desc()

        # roots are matched to each other; everything below starts random
        t0.link = t1
        t1.backlinks.add(t0)
        t0.switch(DendroMatch.rng)
        state = t0

        dmatch = DendroMatch(state)
        dmatch.set_schedule(schedule)
        dmatch.anneal()
        # after anneal(), dmatch.state holds a copy of the best state found
        anneal_states.append(dmatch.state)

    with open('data/initialization.pickle', 'wb') as handle:
        pickle.dump(anneal_states, handle,
                    protocol=pickle.HIGHEST_PROTOCOL)


 Temperature        Energy    Accept   Improve     Elapsed   Remaining
     2.00000       5426.96    68.00%     0.00%     0:00:56     0:00:00
 Temperature        Energy    Accept   Improve     Elapsed   Remaining
     2.00000       1296.29    62.40%     0.00%     0:01:24     0:00:00
 Temperature        Energy    Accept   Improve     Elapsed   Remaining
     2.00000       2465.03    73.60%     0.00%     0:01:23     0:00:00
 Temperature        Energy    Accept   Improve     Elapsed   Remaining
     2.00000       3401.18    74.00%     0.80%     0:01:33     0:00:00
 Temperature        Energy    Accept   Improve     Elapsed   Remaining
     2.00000       1031.76    64.40%     0.40%     0:01:29     0:00:00
 Temperature        Energy    Accept   Improve     Elapsed   Remaining
     2.00000       5981.88    67.60%     0.00%     0:01:09     0:00:00
 Temperature        Energy    Accept   Improve     Elapsed   Remaining
     2.00000       7823.20    77.60%     0.00%     0:01:25     0:00:00
 Temp

# Continuous tracking

In [None]:
def cost(params, hyperparams, supplement):
    """
    GPU-optimized function for computing the cost of a set of linking
    matrices. Uses the negative log-likelihood as the cost function.
    Params:
        params (list of ndarrays): The time-series of linking matrices
            (unnormalized; each row is passed through a softmax). The last
            row and last column are the extra source/sink entries.
        hyperparams (Hyperparams): hyperparameters
        supplement (tuple of Supplements): time-series of auxiliary 
            dendrogram data
    Returns:
        scalar: sum over all time steps of the structure, mass, alignment
            and source/sink penalties.
    """

    # The four terms are accumulated separately and only added at the end.
    cost_structure = 0
    cost_mass = 0
    cost_alignment = 0
    cost_src = 0
    
    hyp = hyperparams
    
    for t in range(len(params)):
        # each row becomes a probability distribution over destinations
        link_mat = softmax(params[t], axis=1)
        supp = supplement[t]
        mass_i, mass_j = supp.mass_i, supp.mass_j
        x_i, x_j = supp.x_i, supp.x_j
        cov_i, cov_j = supp.cov_i, supp.cov_j
        v_i_prior = supp.v_i_prior
        metric_j = supp.metric_j
        # NOTE(review): weights_j is read but not used below.
        weights_j = supp.weights_j
        local_pairs = supp.local_pairs_i
        
        # mass_mat_all[i, j] = mass sent from i to j (row i scaled by mass_i)
        mass_mat_all = jnp.einsum('ij,i->ij', link_mat, mass_i)
        # drop the source row and sink column
        mass_mat = mass_mat_all[:-1, :-1]
        # total mass arriving at each destination (source row included,
        # sink column excluded)
        m_j = jnp.sum(mass_mat_all[:, :-1], axis=0)
        
        # dendrogram structure conservation
        dE_structure = 0
        for i1, i2, weight in local_pairs:
            m1, m2 = mass_i[i1], mass_i[i2]
            # relative masses of the pair (right-hand side uses the
            # un-normalized m1, m2)
            m1, m2 = m1/(m1+m2), m2/(m1+m2)
            var_in = 0
            if i1 != i2:
                var_in = m1 * m2
            # not necessarily normalized (we ignore sink!)
            out_distr = m1 * link_mat[i1, :-1] + m2 * link_mat[i2, :-1]
            # expected squared tree distance between two independent draws
            # from out_distr, halved
            joint_distr = jnp.outer(out_distr, out_distr) * metric_j**2
            var_out = 0.5 * jnp.sum(joint_distr)
            penalty = (var_out - var_in)**2 / (2 * hyp.struct_sigma**2)
            dE_structure += weight * penalty
        
        # mass conservation
        # dm is the ratio of actual to received mass; chi is 0 when dm == 1
        dm = mass_j/m_j
        chi = (dm - 1)**2 / dm
        penalties = chi**2 / (2 * hyp.mass_sigma**2)
        dE_mass = jnp.sum(penalties)
        
        # shape and location alignment
        dE_align = 0
        for j in range(len(x_j)):
            # share of destination j's incoming mass contributed by each i
            distr = mass_mat[:, j] / jnp.sum(mass_mat[:, j])
            # predicted center: mixture mean of positions shifted by the
            # velocity prior
            mu_mix = jnp.einsum('i,ij->j', distr, (x_i + v_i_prior))
            # mixture mean of the unshifted positions, used for the spread
            mu_bar = jnp.einsum('i,ij->j', distr, x_i)
            d_mu = x_i - mu_bar
            outers = jnp.einsum('ij,ik->ijk', d_mu, d_mu)
            # mixture covariance: within-component plus between-component
            cov_mix = jnp.einsum('i,ijk->jk', distr, (cov_i + outers))
            
            mu_err = jnp.linalg.norm(mu_mix - x_j[j])
            cov_err = jnp.linalg.norm(cov_mix - cov_j[j])
            penalty = mu_err**2 / (2 * hyp.align_mu_sigma**2)
            penalty += cov_err**2 / (2 * hyp.align_cov_sigma**2)
            # averaged over destinations
            dE_align += penalty / len(x_j)

        # source/sink term
        # penalize mass leaving the source row and mass entering the sink
        # column (the source->sink corner is excluded from both sums)
        dE_src = hyp.srcsink * (jnp.sum(mass_mat_all[-1, :-1]) \
                                + jnp.sum(mass_mat_all[:-1, -1]))

        cost_structure += dE_structure
        cost_mass += dE_mass
        cost_alignment += dE_align
        cost_src += dE_src

    return cost_structure + cost_mass + cost_alignment + cost_src

# Compiled versions. Arguments 1 and 2 (hyperparams, supplement) are marked
# static, so they must be hashable. `dcost` is the gradient with respect to
# `params` only. The second line rebinds the name `cost` to the compiled
# function, so `dcost` must be built first from the plain Python function.
dcost = jit(grad(cost, argnums=0), static_argnums=(1,2))
cost = jit(cost, static_argnums=(1,2))

In [None]:
class Hyperparams:
    # Constants read by `cost` and `gen_supplement`. The sigma values are
    # the scales that divide each squared error in `cost`, so a smaller
    # sigma makes that term weigh more.
    align_mu_sigma = 10     # scale of the center-of-mass error in `cost`
    align_cov_sigma = 1     # scale of the covariance error in `cost`
    struct_sigma = 0.3      # scale of the structure-conservation error in `cost`
    mass_sigma = .01        # scale of the mass-conservation error in `cost`
    vel_mu = 30             # magnitude factor of the velocity prior in `gen_supplement`
    srcsink = 3000          # multiplier on the source/sink mass in `cost`
    r_mid = 70              # NOTE(review): not referenced in the visible code

def initialize_params(anneal_state, link_logit=10):
    """
    Read annealing results into linking matrix, return linking matrix.

    The matrix has one row per node of the annealed tree and one column per
    node of the tree it is linked to, plus one extra row (source) and one
    extra column (sink). Every annealed link (i -> j) and the source -> sink
    corner are set to `link_logit`; everything else is 0. Since `cost`
    applies a row-wise softmax, a larger `link_logit` makes the initial
    links more certain.
    Params:
        anneal_state (SimpleNode): root of an annealed, linked tree
        link_logit (float): value written for each selected entry
            (default 10, the previously hard-coded value)
    Returns:
        ndarray of shape (n_nodes + 1, n_linked_nodes + 1)
    """
    n_rows = len(anneal_state.descendants) + 1
    n_cols = len(anneal_state.link.descendants) + 1
    param_mat = np.zeros((n_rows, n_cols))
    param_mat[-1, -1] = link_logit    # source goes to sink
    for node in anneal_state.descendants:
        param_mat[node.data.id, node.link.data.id] = link_logit
    return param_mat

def gen_supplement(d, d_next, hyperparams=None):
    """
    Bundle the per-frame data that `cost` needs for one time step.
    Params:
        d (Dendrogram): dendrogram of the current frame (index i)
        d_next (Dendrogram): dendrogram of the next frame (index j)
        hyperparams (Hyperparams): provides `vel_mu`. When omitted, the
            module-level `hyp` is used, as before.
    Returns:
        Supplement: plain object with attributes
            mass_i        - masses of d, including the appended source mass
            mass_j        - masses of d_next, without the source entry
            x_i, x_j      - branch centers of mass
            cov_i, cov_j  - branch covariance matrices
            v_i_prior     - per-branch displacement prior, pointing away
                            from the origin with magnitude
                            vel_mu * r / (r + 1)
            metric_j      - tree metric of d_next
            weights_j     - mass fractions of d_next's branches
            local_pairs_i - list of (i1, i2, weight) for every pair
                            i1 <= i2 of branches in d at tree distance <= 1
                            (a branch with itself, or parent and child)
    """
    if hyperparams is None:
        hyperparams = hyp

    # Deliberately a plain class: instances hash by identity, which is what
    # lets them be passed to the jitted `cost` as a static argument (a
    # tuple/namedtuple would try to hash the arrays it contains).
    class Supplement:
        pass

    s = Supplement()
    s.mass_i = d.masses
    s.mass_j = d_next.masses[:-1]
    s.x_i = d.x
    s.x_j = d_next.x
    s.cov_i = np.array([b.covariance for b in d.branches])
    s.cov_j = np.array([b.covariance for b in d_next.branches])
    s.v_i_prior = [hyperparams.vel_mu * x / (np.linalg.norm(x)+1) for x in d.x]
    s.metric_j = d_next.metric
    s.weights_j = np.array([b.mass_frac for b in d_next.branches])

    # Weight each local pair by its share of the total branch mass.
    # The previous expression `m1 + m2 / total` divided only m2 because of
    # operator precedence; the pair's combined mass is what is normalized.
    total_mass = np.sum(d.masses[:-1])
    n_rows, n_cols = d.metric.shape
    local_pairs = []
    for i1 in range(n_rows):
        for i2 in range(i1, n_cols):
            if d.metric[i1, i2] <= 1:
                weight = (d.masses[i1] + d.masses[i2]) / total_mass
                local_pairs.append((i1, i2, weight))
    s.local_pairs_i = local_pairs
    return s

hyp = Hyperparams()

# One linking matrix and one Supplement per pair of consecutive frames:
# entry t describes the step from dendro_frames[t] to dendro_frames[t+1].
params = []
supplement = []
for t, anneal_state in enumerate(anneal_states):
    param_frame = initialize_params(anneal_state)
    params.append(param_frame)
    d, d_next = dendro_frames[t], dendro_frames[t+1]
    supp = gen_supplement(d, d_next)
    supplement.append(supp)
# a tuple, so it can be passed to the jitted cost as a static argument
supplement = tuple(supplement)

# cost of the annealing-based initialization
print(cost(params, hyp, supplement))

94215500000.0


In [None]:
"""
Perform gradient descent optimization
"""

opt_init, opt_update, get_params = optimizers.adam(step_size=1)

@jit
def update(i, opt_state):
    """Apply one Adam step and return the new optimizer state.

    `i` is the global iteration index (Adam uses it for bias correction).
    Local names avoid shadowing the imported `grad` and the global `params`.
    """
    current_params = get_params(opt_state)
    gradients = dcost(current_params, hyp, supplement)
    return opt_update(i, gradients, opt_state)

opt_state = opt_init(params)

loss = []
n_epochs, n_steps = 1, 3000
for epoch in range(n_epochs):
    for step in range(n_steps):
        if step%100 == 0:
            # record the loss every 100 steps
            print(".", end='')
            params = get_params(opt_state)
            loss.append(cost(params, hyp, supplement))
        # Pass the iteration count across all epochs. Restarting at 0 every
        # epoch would reset Adam's bias correction while its moment
        # estimates carry over (identical to before when n_epochs == 1).
        opt_state = update(epoch * n_steps + step, opt_state)
    print()
print()

params = get_params(opt_state)
print(cost(params, hyp, supplement))
print()

..............................

16152372.0



In [None]:
"""
Save optimized linking matrices
"""

# `params` is the list of optimized (pre-softmax) linking matrices.
with open('results/linkmats.pickle', 'wb') as out_file:
    pickle.dump(params, out_file, protocol=pickle.HIGHEST_PROTOCOL)