In [1]:
import matplotlib.pyplot as plt
import matplotlib._color_data as mcd
import matplotlib.colors as mplc

from jax.nn import softmax
from jax import grad, jit
import jax.numpy as jnp
from jax.experimental import optimizers

import numpy as np
import cv2

In [4]:
import importlib
import dendrogram
importlib.reload(dendrogram)
from dendrogram import *

In [5]:
fn = 3
dendro_frames = []
for frame_num in range(fn, fn+10):
#     orig = cv2.imread("test_images/tracer/{}.png".format(frame_num),
#                       cv2.IMREAD_GRAYSCALE)
#     orig = orig[50:825, 125:900]
#     orig = cv2.resize(orig, (300, 300))
    orig = cv2.imread("gen/{:02d}.png".format(frame_num), cv2.IMREAD_GRAYSCALE).astype(np.float32)
    im = cv2.GaussianBlur(orig, (3, 3), cv2.BORDER_CONSTANT)
    d = make_dendrogram(-im)
    dendro_frames.append(d)

In [72]:
def cost(params, supplement, hyperparams):        

    cost_inertia = 0
    cost_momentum = 0
    cost_vdiv = 0
    cost_locality = 0
    cost_mass = 0
    cost_src_sink = 0
    c0, c1, c2, c3, c4, c5, r_mid, prox_s = hyperparams
    
    prev_p_j = None
    for t in range(len(params)):
        link_mat = softmax(params[t][:-1, :, 0], axis=1)
        v_mat = params[t][:-1, :-1, 1:]
        weights, mass_cur, mass_nxt, x_cur, x_nxt, locality_mat = supplement[t]
        
        mass_mat_all = jnp.einsum('ij,i->ij', link_mat, mass_cur)
        mass_mat = mass_mat_all[:, :-1]
        p_mat = jnp.einsum('ij,ijk->ijk', mass_mat, v_mat)
        p_j = jnp.sum(p_mat, axis=0)
        m_j_src = jnp.abs(params[t][-1, :-1, 0])
        m_j = jnp.sum(mass_mat, axis=0) + m_j_src
        
        r = jnp.linalg.norm(x_nxt, axis=1)
        proximity = 1/(np.exp((r-r_mid)/prox_s) + 1)
        
        dE_inertia = 0
        com_j = jnp.einsum('ij,ik,j->jk', mass_mat, x_cur, 1/m_j)
        dx_j = x_nxt - com_j
        overshoot = p_j - jnp.einsum('j,jk->jk', m_j, dx_j)
        penalty = jnp.linalg.norm(overshoot, axis=-1)
        dE_inertia = c0 * jnp.sum(penalty)
        
        dE_momentum = 0
        if t >= 1:
            p_i = jnp.sum(p_mat, axis=1)
            momentum_lost = p_i - prev_p_j
            penalty = weights * jnp.linalg.norm(momentum_lost, axis=-1)
            dE_momentum = c1 * jnp.sum(penalty)
        prev_p_j = p_j
        
        dE_v_div = 0
        for i in range(v_mat.shape[0]):
            v_div = jnp.trace(jnp.cov(v_mat[i, :, :].T))
            dE_v_div += c2 * v_div
        
        dE_locality = 0
        for i in range(link_mat.shape[0]-1):
            outer_prod = jnp.outer(link_mat[i, :-1], link_mat[i, :-1])
            dE_locality += c3 * jnp.mean(outer_prod * locality_mat)
            
        mass_gain = mass_nxt/m_j - 1
        modulator = r / (r + r_mid)
        penalty = modulator * mass_gain**2
        dE_mass = c4 * jnp.sum(penalty)
        
        penalty = c5 * (1-proximity) * m_j_src
        dE_src_sink = jnp.sum(penalty)
        dE_src_sink += c5 * jnp.sum(mass_mat_all[:, -1])
            
        cost_inertia += dE_inertia
        cost_momentum += dE_momentum
        cost_vdiv += dE_v_div
        cost_locality += dE_locality
        cost_mass += dE_mass
        cost_src_sink += dE_src_sink
        
#     print(cost_inertia, cost_momentum, cost_vdiv, cost_locality, cost_mass, cost_src_sink)
    return cost_inertia + cost_momentum + cost_vdiv + cost_locality + cost_mass + cost_src_sink

dcost = jit(grad(cost, argnums=0), static_argnums=(1,2))
cost = jit(cost, static_argnums=(1,2))

In [71]:
def logistic(x, mu=0, s=1):
    return 1/(np.exp((x-mu)/s) + 1)

def initialize_params(d, d_next):
    param_mat = np.zeros((d.N+1, d_next.N+1, 3))
    for i, j in np.ndindex(d.N, d_next.N):
        loc = d.branches[i].loc
        vel = 10 * loc / np.linalg.norm(loc)
        param_mat[i, j, 1:] = vel
        sep = np.linalg.norm(d_next.branches[j].loc - d.branches[i].loc)
        param_mat[i, j, 0] = logistic(sep, mu=20, s=5)
    # TODO: add src and sink terms
    return param_mat

supplement = []
params = []
for t in range(len(dendro_frames)-1):
    d = dendro_frames[t]
    d_nxt = dendro_frames[t+1]
    param_frame = initialize_params(d, d_nxt)
    params.append(param_frame)
    
    weights = np.array([b.weight for b in d.branches])
    mass_cur = np.array([b.exclusive_mass for b in d.branches])
    mass_nxt = np.array([b.exclusive_mass for b in d_nxt.branches])
    x_cur = np.array([b.loc for b in d.branches])
    x_nxt = np.array([b.loc for b in d_nxt.branches])
    locality = d_nxt.locality_mat
    supp = [weights, mass_cur, mass_nxt, x_cur, x_nxt, locality]
    supplement.append(supp)
    
hyperparams = [.01, 0.1, 50000, 100000, 100, 1000, 70, 5]

print(cost(params, supplement, hyperparams))

opt_init, opt_update, get_params = optimizers.adam(step_size=.01)

@jit
def update(i, opt_state):
    params = get_params(opt_state)
    return opt_update(i, dcost(params, supplement, hyperparams), opt_state)

opt_state = opt_init(params)

for step in range(7000):
    opt_state = update(step, opt_state)

params = get_params(opt_state)
print(cost(params, supplement, hyperparams))
print()

46564940.0
16714.982



In [73]:
print(cost(params, supplement, hyperparams))

7415.9736 1872.1123 8.532631 1985.016 1759.809 3673.5396
16714.982


In [74]:
def assign_trajectories(dendro_frames, params):
    traj_counter = 0
    thresh = 0.9
    thresh2 = 0.25
    for b in dendro_frames[0].branches:
        if b.mass_frac > thresh:
            b.traj_id = traj_counter
            traj_counter += 1
    for t in range(len(dendro_frames)-1):
        p = params[t]
        d = dendro_frames[t]
        d_nxt = dendro_frames[t+1]
        
        mass_in = np.array([b.exclusive_mass for b in d.branches])
        mass_mat = np.zeros((d.N+1, d_nxt.N))
        mass_mat[:-1, :] = np.einsum('ij,i->ij', softmax(p[:-1, :-1, 0], axis=1), mass_in)
        mass_mat[-1, :] = p[-1, :-1, 0]
        
        for branch_id in range(d_nxt.N):
            branch = d_nxt.branches[branch_id]
            if branch.mass_frac < thresh or branch.traj_id is not None:
                continue
            mass_breakdown = np.einsum('ij,j->i', mass_mat, d_nxt.hierarchy[branch_id])
            mass_breakdown = mass_breakdown / np.sum(mass_breakdown)
            sorted_inds = np.flip(np.argsort(mass_breakdown))
            traj_id = None
            for ind, m_frac in zip(sorted_inds, mass_breakdown):
                if ind == d.N:
                    if m_frac > thresh2:
                        break
                elif d.branches[ind].traj_id is not None:
                    traj_id = d.branches[ind].traj_id
                    break
            if traj_id == None:
                traj_id = traj_counter
                traj_counter += 1
            
            for twig in branch.list_descendants():
                twig.traj_id = traj_id
                
for d in dendro_frames:
    for b in d.branches:
        b.traj_id = None
assign_trajectories(dendro_frames, params)


In [75]:
def draw_traj(dendro_frames):
    colors = ["aqua", "coral", "chartreuse", "azure",
              "beige", "goldenrod", "brown", "lavender",
              "fuchsia", "silver", "ivory", "yellow"]
#     colors = [mcd.XKCD_COLORS["xkcd:" + c] for c in colors]
    colors = [255*np.array(mplc.to_rgb("xkcd:" + c)) for c in colors]
    
    for t in range(len(dendro_frames)):
        im = np.zeros((200, 200, 3))
        for b in dendro_frames[t].branches:
            if b.mass_frac > 0.9:
                for x, y in b.full_region:
                    im[x, y, :] = colors[b.traj_id % len(colors)]
        cv2.imwrite("results/traj-attempt5/traj-{:02d}.png".format(t), im)
        
draw_traj(dendro_frames)

In [91]:
np.set_printoptions(suppress=True)
for t in range(len(params)):
    p = params[t]
    d = dendro_frames[t]
    d_nxt = dendro_frames[t+1]
#     link_mat = np.round(softmax(p[:, :, 0], axis=1)*100).astype(int)
    mass_in = np.array([b.exclusive_mass for b in d.branches])
    mass_mat = np.zeros((d.N+1, d_nxt.N+1))
    mass_mat[:-1, :] = np.einsum('ij,i->ij', softmax(p[:-1, :, 0], axis=1), mass_in)
    mass_mat[-1, :] = p[-1, :, 0]
    mass_frac_mat = np.einsum('ij,j->ij', mass_mat, 1/(np.sum(mass_mat, axis=0))) * 100
    mass_frac_mat = np.round(mass_frac_mat)
    print(mass_mat.round())
#     print(d_nxt.locality_mat)
    v_mat = p[:, :, 1:]
#     print("FROM")
    count = 0
    for b in d.branches:
        if b.mass_frac <= 0.9:
            mass_mat = np.delete(mass_mat, b.id-count, axis=0)
            count += 1
        else:
            print(b.loc[0].round(), b.loc[1].round(), b.traj_id)
    print()
#     print("TO")
    count = 0
    for b in d_nxt.branches:
        if b.mass_frac <= 0.9:
            mass_mat = np.delete(mass_mat, b.id-count, axis=1)
            count += 1
        else:
            print(b.loc[0].round(), b.loc[1].round(), b.traj_id)
    print(mass_mat.round())
    print()
    print('====================================')

[[   5.    0.    0.    0.    0.]
 [ 183. 3012.    3.    2.    0.]
 [1727.    9.  513.   11.    0.]
 [ 539.  154.  373. 6331.    0.]
 [   0.  -27.   81.    0.    0.]]
-1.0 12.0 0
-0.0 -2.0 1
-3.0 -29.0 2

-2.0 21.0 0
-3.0 -12.0 1
-3.0 -37.0 2
[[3012.    3.    2.    0.]
 [   9.  513.   11.    0.]
 [ 154.  373. 6331.    0.]
 [ -27.   81.    0.    0.]]

[[    2.     0.     0.     0.     0.     0.]
 [  450.  4088.    67.     8.     4.     0.]
 [  942.     2.    11.  2300.     4.     0.]
 [  151.    20.   104.   405. 10004.     0.]
 [    0.    -0.    72.    52.     0.     0.]]
-2.0 21.0 0
-3.0 -12.0 1
-3.0 -37.0 2

-3.0 29.0 0
0.0 2.0 2
-5.0 -21.0 1
-4.0 -44.0 2
[[ 4088.    67.     8.     4.     0.]
 [    2.    11.  2300.     4.     0.]
 [   20.   104.   405. 10004.     0.]
 [   -0.    72.    52.     0.     0.]]

[[    0.     0.     8.     0.     0.     0.     0.]
 [  510.  6130.    76.     6.     6.     3.     0.]
 [ 1083.     3.   975.     2.     1.     1.     0.]
 [    5.     2.     3.   