In [1]:
import matplotlib.pyplot as plt
import matplotlib._color_data as mcd
import matplotlib.colors as mplc

import scipy.optimize
import jax
from jax.nn import softmax
from jax import grad, jit
import jax.numpy as jnp
from jax.experimental import optimizers
import time
from dendrogram import *

import numpy as np
import matplotlib.pyplot as plt
import cv2

In [34]:
# Build one dendrogram per input frame, for 10 consecutive frames starting at `fn`.
fn = 3
dendro_frames = []
for frame_num in range(fn, fn+10):
    # Alternative input (tracer images), kept for reference:
    #     orig = cv2.imread("test_images/tracer/{}.png".format(frame_num),
    #                       cv2.IMREAD_GRAYSCALE)
    #     orig = orig[50:825, 125:900]
    #     orig = cv2.resize(orig, (300, 300))
    frame_path = "gen/{:02d}.png".format(frame_num)
    orig = cv2.imread(frame_path, cv2.IMREAD_GRAYSCALE)
    if orig is None:
        # cv2.imread signals a missing/unreadable file by returning None instead of raising.
        raise FileNotFoundError("could not read frame image: {}".format(frame_path))
    orig = orig.astype(np.float32)
    # The third positional argument of GaussianBlur is sigmaX, not the border type.
    # The original passed cv2.BORDER_CONSTANT there, which is the integer 0, i.e.
    # "derive sigma from the kernel size". Spell that out; results are unchanged.
    im = cv2.GaussianBlur(orig, (3, 3), 0)
    # make_dendrogram is fed the negated blurred image.
    d = make_dendrogram(-im)
    dendro_frames.append(d)

In [26]:
def cost(params, supplement, coefficients):
    """Total cost of a sequence of frame-to-frame link/velocity assignments.

    Parameters
    ----------
    params : sequence of arrays
        ``params[t][i, j, 0]`` is the link logit between row ``i`` and column ``j``
        (soft-maxed over ``i``, so every column sums to one) and
        ``params[t][i, j, 1:]`` is the velocity attached to that link.
    supplement : sequence
        ``supplement[t]`` unpacks to ``(weights, masses, x_cur, x_nxt, locality_mat)``.
        ``weights``, ``masses`` and ``x_cur`` are indexed by ``i``; ``x_nxt`` is
        indexed by ``j``. ``locality_mat`` is multiplied element-wise with the
        ``j x j`` outer product of each row of link weights
        (presumably shape ``(J, J)`` -- confirm against the dendrogram module).
    coefficients : sequence of four numbers
        ``(c0, c1, c2, c3)``: weights of the inertia, momentum, velocity-spread
        and locality terms respectively.

    Returns
    -------
    scalar
        Sum of the four weighted terms over all entries of ``params``.
    """
    c0, c1, c2, c3 = coefficients

    cost_inertia = 0
    cost_momentum = 0
    cost_vdiv = 0
    cost_locality = 0

    prev_p_j = None
    for t in range(len(params)):
        link_mat = softmax(params[t][:, :, 0], axis=0)
        v_mat = params[t][:, :, 1:]
        weights, masses, x_cur, x_nxt, locality_mat = supplement[t]

        # Mass and momentum carried by every (i, j) link, and their per-j totals.
        mass_mat = jnp.einsum('ij,i->ij', link_mat, masses)
        p_mat = jnp.einsum('ij,ijk->ijk', mass_mat, v_mat)
        p_j = jnp.sum(p_mat, axis=0)
        m_j = jnp.sum(mass_mat, axis=0)

        # Inertia: the mass-weighted mean velocity arriving at j should carry the
        # mass-weighted mean source position onto x_nxt[j].
        com_j = jnp.einsum('ij,ik,j->jk', mass_mat, x_cur, 1/m_j)
        v_j = jnp.einsum('ij,i->ij', p_j, 1/m_j)
        overshoot = v_j - (x_nxt - com_j)
        cost_inertia += c0 * jnp.sum(jnp.linalg.norm(overshoot, axis=-1))

        # Momentum: what leaves row i now should match what arrived at the same
        # index in the previous step (there is no previous step at t == 0).
        if t >= 1:
            p_i = jnp.sum(p_mat, axis=1)
            momentum_lost = p_i - prev_p_j
            penalty = weights * jnp.linalg.norm(momentum_lost, axis=-1)
            cost_momentum += c1 * jnp.sum(penalty)
        prev_p_j = p_j

        # Velocity spread: trace(cov(v_mat[i].T)) is the sum over components of the
        # ddof=1 variance across j, so one vectorised var() replaces the per-row
        # Python loop (which was unrolled into the jit trace).
        cost_vdiv += c2 * jnp.sum(jnp.var(v_mat, axis=1, ddof=1))

        # Locality: mean over (j, j') of link[i, j] * link[i, j'] * locality[j, j'],
        # summed over i -- again without a Python loop over rows.
        pair_weights = link_mat[:, :, None] * link_mat[:, None, :]
        cost_locality += c3 * jnp.sum(
            jnp.mean(pair_weights * locality_mat, axis=(1, 2)))

    return cost_inertia + cost_momentum + cost_vdiv + cost_locality

# `supplement` and `coefficients` are lists of arrays / numbers. Static jit arguments
# must be hashable, which lists are not, so they are passed as ordinary traced
# arguments instead. Only `params` (argument 0) is differentiated.
dcost = jit(grad(cost, argnums=0))
cost = jit(cost)

In [39]:
def initialize_params(dendro_frames, thresh):
    """Build the initial parameter tensor for every pair of consecutive frames.

    For frames ``t`` and ``t+1`` the result holds an array of shape
    ``(dendro_frames[t].N, dendro_frames[t+1].N, 3)`` where

    * ``[..., 0]`` is 1 when the two branch locations are closer than ``thresh``
      and 0 otherwise;
    * ``[..., 1:]`` is the source branch location rescaled to length 10, the same
      for every target ``j``. A source located exactly at the origin gets a zero
      vector instead of the NaN that 0/0 would produce.

    Parameters
    ----------
    dendro_frames : sequence
        Objects exposing ``N`` and ``branches[k].loc``.
    thresh : float
        Distance below which two branches start out linked.

    Returns
    -------
    list of numpy.ndarray
        One array per consecutive frame pair (empty for fewer than two frames).
    """
    params = []
    for t in range(len(dendro_frames) - 1):
        cur = dendro_frames[t]
        nxt = dendro_frames[t + 1]
        param_mat = np.zeros((cur.N, nxt.N, 3))
        nxt_locs = [np.asarray(nxt.branches[j].loc, dtype=float)
                    for j in range(nxt.N)]
        for i in range(cur.N):
            loc = np.asarray(cur.branches[i].loc, dtype=float)
            # The initial velocity depends only on the source branch, so compute it
            # once per row rather than once per (i, j) pair.
            dist_from_origin = np.linalg.norm(loc)
            if dist_from_origin > 0:
                param_mat[i, :, 1:] = 10 * loc / dist_from_origin
            for j, nxt_loc in enumerate(nxt_locs):
                if np.linalg.norm(nxt_loc - loc) < thresh:
                    param_mat[i, j, 0] = 1
        params.append(param_mat)
    return params

# Initial guess: branches closer than 20 (in `loc` units) start out linked.
params = initialize_params(dendro_frames, 20)
# NOTE(review): `locs` and `wts` are never used in the visible cells, and `masses` /
# `locality` are overwritten inside the loop below -- these four lines look vestigial.
locs = []
masses = []
wts = []
locality = []
# Per frame pair, the fixed (non-optimised) data that `cost` unpacks as
# (weights, masses, x_cur, x_nxt, locality_mat).
supplement = []
for t in range(len(dendro_frames)-1):
    d = dendro_frames[t]
    d_nxt = dendro_frames[t+1]
    weights = np.array([b.weight for b in d.branches])
    masses = np.array([b.exclusive_mass for b in d.branches])
    x_cur = np.array([b.loc for b in d.branches])
    x_nxt = np.array([b.loc for b in d_nxt.branches])
    locality = d_nxt.locality_mat
    supp = [weights, masses, x_cur, x_nxt, locality]
    supplement.append(supp)
# (c0, c1, c2, c3): weights of the inertia, momentum, velocity-spread and locality terms.
coefficients = [100, 0.1, 50000, 100000]

# Cost of the initial guess.
print(cost(params, supplement, coefficients))

opt_init, opt_update, get_params = optimizers.adam(step_size=.01)

@jit
def update(i, opt_state):
    """Apply one Adam step; `supplement` and `coefficients` are read from cell globals."""
    params = get_params(opt_state)
    return opt_update(i, dcost(params, supplement, coefficients), opt_state)

opt_state = opt_init(params)

# Fixed budget of 700 optimiser steps; there is no convergence check.
for step in range(700):
    opt_state = update(step, opt_state)

# Rebinds the global `params` to the optimised values used by the later cells.
params = get_params(opt_state)
print(cost(params, supplement, coefficients))
print()

628975.6
6981.9062



In [42]:
import importlib
import dendrogram
importlib.reload(dendrogram)
from dendrogram import *

def assign_trajectories(dendro_frames, params):
    """Label branches with trajectory ids, propagating ids from frame to frame.

    Branches are mutated in place (``traj_id`` is set); nothing is returned.
    Branches whose ``traj_id`` is already set are left alone, so callers should
    reset ``traj_id`` to ``None`` before a fresh run.

    Frame 0: every branch with ``mass_frac > 0.9`` starts a new trajectory.
    Frame t >= 1: for every unlabeled branch with ``mass_frac >= 0.9`` the mass it
    receives from each branch of frame t-1 is computed from
    ``softmax(params[t-1][:, :, 0], axis=0)`` and the source ``exclusive_mass``.
    The branch inherits the trajectory of the largest eligible donor
    (``mass_frac > 0.9``) when one exists, otherwise it starts a new trajectory.
    The chosen id is written to everything in ``branch.list_descendants()``.

    Parameters
    ----------
    dendro_frames : sequence
        Frames exposing ``N``, ``branches`` and (from frame 1 on) ``hierarchy``;
        branches expose ``mass_frac``, ``exclusive_mass``, ``id``, ``parent``,
        ``traj_id`` and ``list_descendants()``.
    params : sequence of arrays
        ``params[t-1]`` links frame t-1 (rows) to frame t (columns).
    """
    traj_counter = 0
    thresh = 0.9   # minimum mass_frac for a branch to carry a trajectory
    thresh2 = 0.5  # minimum share of incoming mass a source subtree must supply
    for b in dendro_frames[0].branches:
        if b.mass_frac > thresh:
            b.traj_id = traj_counter
            traj_counter += 1
    for t in range(1, len(dendro_frames)):
        p = params[t-1]
        d_prev = dendro_frames[t-1]
        d = dendro_frames[t]

        src_masses = np.array([b.exclusive_mass for b in d_prev.branches])
        mass_mat = np.einsum('ij,i->ij', softmax(p[:, :, 0], axis=0), src_masses)

        for branch_id in range(d.N):
            branch = d.branches[branch_id]
            if branch.mass_frac < thresh or branch.traj_id is not None:
                continue
            # Mass received by this branch from each branch of the previous frame.
            mass_brkdwn = np.einsum('ij,j->i', mass_mat, d.hierarchy[branch_id])
            total_in = np.sum(mass_brkdwn)

            # Walk from the biggest single donor up through its ancestors.
            # NOTE(review): there is no `break`, so the *last* (highest) ancestor
            # whose subtree passes thresh2 wins; if the smallest such subtree was
            # intended, a `break` is missing -- confirm before changing.
            src_branch = d_prev.branches[np.argmax(mass_brkdwn)]
            best_bullets = None
            while src_branch:
                mass_contr = 0
                bullets = np.zeros(d_prev.N)
                # `twig`, not `t`: the original reused `t` and clobbered the frame index.
                for twig in src_branch.list_descendants():
                    mass_contr += mass_brkdwn[twig.id]
                    if twig.mass_frac > thresh:
                        bullets[twig.id] = 1
                # Guard the ratio so a branch receiving no mass does not evaluate 0/0.
                if total_in > 0 and mass_contr / total_in > thresh2:
                    best_bullets = bullets
                src_branch = src_branch.parent

            if best_bullets is not None and best_bullets.any():
                # Restrict the argmax to eligible donors. Multiplying by the 0/1 mask
                # made argmax fall back to branch 0 whenever nothing was eligible.
                masked = np.where(best_bullets > 0, mass_brkdwn, -np.inf)
                largest_donor = d_prev.branches[int(np.argmax(masked))]
                assert largest_donor.traj_id is not None, \
                    "eligible donor branch has no trajectory id"
                traj = largest_donor.traj_id
            else:
                # No eligible donor: this branch starts a new trajectory.
                traj = traj_counter
                traj_counter += 1
            for twig in branch.list_descendants():
                twig.traj_id = traj

# Clear every branch's trajectory label first: assign_trajectories skips branches whose
# traj_id is already set, so stale labels from an earlier run of this cell would stick.
for d in dendro_frames:
    for b in d.branches:
        b.traj_id = None
# Label trajectories using the current global `params`.
assign_trajectories(dendro_frames, params)
                

In [45]:
def draw_traj(dendro_frames, shape=(200, 200),
              out_pattern="results/traj-attempt5/traj-{:02d}.png"):
    """Write one image per frame with every high-mass branch painted in its trajectory colour.

    Parameters
    ----------
    dendro_frames : sequence
        Frames whose branches expose ``mass_frac``, ``traj_id`` and ``full_region``
        (an iterable of ``(x, y)`` pixel indices).
    shape : (int, int), optional
        Height and width of the canvas; the default matches the original hard-coded size.
    out_pattern : str, optional
        ``str.format`` pattern that receives the frame index.

    Raises
    ------
    IOError
        If OpenCV reports that an image could not be written.

    Notes
    -----
    Palette entry 0 ("black") is the background. Trajectory ids cycle through the
    remaining entries, so ids beyond the palette length reuse colours instead of
    raising IndexError. Branches without a trajectory id are left as background.
    Colours are stored as RGB triples while cv2.imwrite interprets channels as BGR,
    so the saved hues do not match the colour names; this is left unchanged so the
    output stays comparable with earlier runs.
    """
    import os

    names = ["black", "aqua", "coral", "chartreuse", "azure",
             "beige", "goldenrod", "brown", "lavender",
             "fuchsia", "silver", "ivory", "yellow"]
    # matplotlib resolves "xkcd:<name>" directly, so the private
    # matplotlib._color_data table is not needed here.
    palette = [255*np.array(mplc.to_rgb("xkcd:" + name)) for name in names]
    n_traj_colors = len(palette) - 1

    out_dir = os.path.dirname(out_pattern)
    if out_dir:
        # cv2.imwrite does not create missing directories; it just returns False.
        os.makedirs(out_dir, exist_ok=True)

    for t in range(len(dendro_frames)):
        im = np.zeros((shape[0], shape[1], 3))
        for b in dendro_frames[t].branches:
            if b.mass_frac > 0.9 and b.traj_id is not None:
                color = palette[1 + b.traj_id % n_traj_colors]
                for x, y in b.full_region:
                    im[x, y, :] = color
        out_path = out_pattern.format(t)
        if not cv2.imwrite(out_path, im):
            raise IOError("could not write {}".format(out_path))
        
draw_traj(dendro_frames)

In [None]:
# Debug dump of the links and velocities for every pair of consecutive frames.
# Uses the same layout as `cost`: params[t] has shape (n_cur, n_next, 3), with
# [..., 0] the link logits (soft-maxed over the source axis 0) and [..., 1:] the
# per-link velocities. The frames are `dendro_frames`.
for t in range(len(params)):
    p = np.asarray(params[t])
    src = dendro_frames[t]
    dst = dendro_frames[t+1]

    # Link strengths as integer percentages.
    link_mat = np.round(np.asarray(softmax(p[:, :, 0], axis=0))*100).astype(int)
    print(link_mat)
    print(dst.locality_mat)
    v_mat = p[:, :, 1:]

    print("FROM")
    for i, b in enumerate(src.branches):
        # Mean velocity of source branch i over all of its outgoing links.
        v = np.round(np.mean(v_mat[i, :, :], axis=0))
        print(np.round(b.mass_frac, 2), v, b.vol, b.traj_id)
    print()
    print("TO")
    for b in dst.branches:
        print(np.round(b.mass_frac, 2), b.traj_id)

    # Keep only rows/columns of branches with mass_frac > 0.9 (the threshold used
    # when drawing); boolean masks replace the repeated np.delete bookkeeping.
    keep_src = np.array([b.mass_frac > 0.9 for b in src.branches], dtype=bool)
    keep_dst = np.array([b.mass_frac > 0.9 for b in dst.branches], dtype=bool)
    link_mat = link_mat[keep_src][:, keep_dst]
    print(link_mat)
    print()
    print()