In [None]:
import itertools
import jax
import jax.numpy as jnp
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from skimage.transform import rescale

from basdi import *
from basdi.utils import frc
from basdi.typing import *

In [None]:
# ---- Hyperparameters ----
# Motion model: when True, the drift is extrapolated with its current velocity.
track_velocity = True

# Input and time binning
data_name = "Pos2"  # input prefix; the localizations are read from f"{data_name}_allLocs.txt"
block_size = 8_000  # frames per acquisition block
binning = 200       # frames per time bin inside a block
n_slices = 23       # number of z-slices

# Localization-error model
k_err = 450  # lateral localization error = k_err / sqrt(intensity)
z_err = 75   # constant z-axis localization error

# SMC initialization and proposal step sizes
th0 = [0, -300, 0]      # initial drift (x, y, z)
s0 = 5e-2               # initial log-scale of the expansion models
sigma = [5, 5, 1]       # per-step drift step size (x, y, z)
sigma_b = [50, 50, 10]  # drift step size at block boundaries
sigma_s = 1e-4          # per-step step size of the log-scale

n_samples = 10_000  # number of SMC particles (rows of init_states)

# Seeds: JAX key for the SMC runs, NumPy global RNG for initial states and FRC splits.
key = jax.random.PRNGKey(4242)
np.random.seed(4242)

In [None]:
all_data = pd.read_csv(f"{data_name}_allLocs.txt", sep="\t")

locs = all_data[["x", "y", "z"]].to_numpy()
intensity = all_data["Intensity"].to_numpy()
# Lateral localization error = k_err / sqrt(intensity). Cast to float first:
# jax.lax.rsqrt (used previously) rejects integer inputs, so an integer
# "Intensity" column used to crash here.
loc_err = k_err / np.sqrt(intensity.astype(np.float32))
# (N, 3) per-localization errors: the lateral error for x and y, z_err for z.
errs = np.stack([loc_err, loc_err, np.full_like(loc_err, z_err)], axis=1)

# Slice ordering: each slice name maps to its position in the interleaved
# sequence 0, h, 1, h+1, ... (h = (n_slices + 1) // 2), presumably the
# acquisition order -- TODO confirm. reshape(2, -1) needs n_slices + 1 even.
if n_slices % 2 != 1:
    raise ValueError(f"interleaved slice order requires an odd n_slices, got {n_slices}")
s = np.arange(n_slices + 1).reshape(2, -1).transpose().reshape(-1)[:-1]
slice_seq = [f"slice{x}" for x in s]

frames = all_data["Frame"].to_numpy().astype(int)
slice_dic = dict(zip(slice_seq, range(len(slice_seq))))
u, inv = np.unique(all_data["Slice"].to_numpy(), return_inverse=True)
slice_idx = np.array([slice_dic[x] for x in u])[inv]

# Flat time-bin index over (block, slice position, bin within block), row-major.
# The strides use the full slice sequence and ceil(block_size / binning) so
# bins of different (block, slice) groups can never collide. The previous
# len(u) stride merged groups whenever some slice had no localizations, and a
# floor stride merged them whenever binning did not divide block_size.
n_slice_pos = len(slice_seq)
bins_per_block = -(-block_size // binning)
block = (frames - 1) // block_size
block_res = ((frames - 1) % block_size) // binning
h_idx = (block * n_slice_pos + slice_idx) * bins_per_block + block_res

In [None]:
# build a binned data sequence with padding

def data_gen(chunk=128):
    """Yield one zero-padded ``(locs, errs, mask)`` triple per time bin.

    Bins are visited in order ``h_idx = 0, 1, ..., h_idx.max()``. Empty bins
    yield zero-length arrays. Inside a bin, localizations keep their original
    row order. Each bin is zero-padded to the next multiple of ``chunk`` rows
    (presumably to limit the number of distinct array shapes seen by JAX --
    TODO confirm), and the boolean ``mask`` marks the real rows.

    Reads the notebook globals ``h_idx``, ``locs`` and ``errs``.

    Args:
        chunk: padding granularity in rows (default 128, as before).
    """
    # One stable sort instead of a full boolean scan per bin. The old
    # `h_idx == i` loop cost O(n_locs * n_bins). Stability preserves the
    # original row order inside each bin, exactly as boolean masking did.
    order = np.argsort(h_idx, kind="stable")
    bounds = np.searchsorted(h_idx[order], np.arange(h_idx.max() + 2))
    for start, stop in zip(bounds[:-1], bounds[1:]):
        rows = order[start:stop]
        n = len(rows)
        padding = -n % chunk  # rows needed to reach the next multiple of chunk
        d0 = np.pad(locs[rows], [[0, padding], [0, 0]])
        d1 = np.pad(errs[rows], [[0, padding], [0, 0]])
        d2 = np.arange(n + padding) < n
        yield d0, d1, d2

# Materialize every bin up front: data[i] is the padded (locs, errs, mask) triple of bin h_idx == i.
data = tuple(tqdm(data_gen()))

In [None]:
def block_transition(key, states, step=None, *, p=0.1, s=10):
    """Random-walk proposal used at block boundaries.

    Every particle (row of ``states``) receives a Gaussian kick scaled per
    axis by the notebook global ``sigma_b``. With probability ``p`` a
    particle's kick is additionally multiplied by ``s``, which gives a
    heavier-tailed proposal. ``step`` is accepted for interface compatibility
    and ignored.

    Returns:
        ``(new_states, ids)`` where ``ids`` is ``arange(n_particles)``
        (identity indices; presumably ancestor indices for the SMC driver).
    """
    noise_key, jump_key = jax.random.split(key)
    n_particles = states.shape[0]
    kick = jax.random.normal(noise_key, shape=states.shape) * jnp.asarray(sigma_b)
    # A uniform draw above p keeps the ordinary kick; otherwise it is amplified by s.
    keep_small = jax.random.uniform(jump_key, shape=(n_particles, 1)) > p
    multiplier = jnp.where(keep_small, 1.0, s)
    return states + kick * multiplier, jnp.arange(n_particles)

In [None]:
# Per-localization output table, ordered like the concatenated bins of `data`.
# data_gen selects rows bin by bin and keeps the original row order inside
# each bin, so a *stable* sort on h_idx alone reproduces that order. Sorting
# additionally by "Frame" only matched when the input file was already
# frame-sorted within every bin. The corrected coordinates written in later
# (locs_c from smc_apply_drift) are assumed to come back in this concatenated
# bin order -- TODO confirm against basdi.
corrected = (
    all_data
    .assign(h_idx=h_idx)
    .sort_values(by="h_idx", kind="stable")
    .reset_index(drop=True)
)

## Model-1: Drift only

In [None]:
# Start this model from the raw, uncorrected coordinates.
locs = locs_c = all_data.loc[:, ["x", "y", "z"]].to_numpy()

# Per-axis `ps` for LMModel: ps for x and y, 4 * ps for z.
ps = 50
hm = LMModel(ps=(ps, ps, 4 * ps), norm_axis=(0, 1))

def transition_model(key, states, step):
    """SMC proposal for the drift-only model.

    Particle state layout: [0:3] drift (x, y, z); [3:6] displacement over the
    previous step, used as a velocity when ``track_velocity`` is set.

    On every ``bins_per_block``-th step (the first time bin of a new
    (block, slice) group, assuming ``step`` is the bin index h_idx), the drift
    takes a large ``block_transition`` jump and the velocity is reset to zero.
    On all other steps, the drift is extrapolated by the velocity and passed
    through ``gauss_transition_model`` with per-axis ``sigma``. The realised
    displacement is then stored as the new velocity.
    """
    # Bins per (block, slice) group = ceil(block_size / binning). This replaces
    # a hard-coded 40 (== 8000 // 200) that silently broke when the binning
    # hyperparameters changed.
    bins_per_block = -(-block_size // binning)
    if step % bins_per_block == 0:
        next_state_a, next_id = block_transition(key, states[:, :3])
        next_state_b = jnp.zeros_like(next_state_a)
        return jnp.c_[next_state_a, next_state_b], next_id
    else:
        v = states[:, 3:6] if track_velocity else 0
        next_state_a = states[:, :3] + v
        next_state_a, next_id = gauss_transition_model(key, next_state_a, sigmas=sigma)
        next_state_b = next_state_a - states[:, :3]
        return jnp.c_[next_state_a, next_state_b], next_id


class CustomDrift(DriftModel):
    """Drift-only model: a localization is corrected by adding the first
    three state entries (x, y, z offsets) to it."""

    @classmethod
    def apply(cls, locs, states):
        # When called through compute_likelihoods, `states` is one particle's row.
        offset = states[:3]
        return locs + offset

    @classmethod
    def compute_likelihoods(cls, model, obs, states):
        locs, err, mask = obs
        # Map the single-particle correction over the particle axis of `states`.
        shift_all = jax.vmap(cls.apply, in_axes=(None, 0))
        return model.e_log_ll(shift_all(locs, states), err, mask)

drift_model = CustomDrift

# Iterative refinement: rebuild the density model from the current corrected
# coordinates (lateral errors scaled by m, z unchanged), re-run SMC drift
# inference from fresh initial states (same `key` every round), and apply the
# reduced state trajectory to the data.
for m in [2, 2, 2, 1, 1]:
    hm.build(
        locs=locs_c,
        errs=errs * [m, m, 1],
        n_samples=500,
    )

    # State layout: drift (0:3) initialised at th0, velocity (3:6) at zero.
    init_states = np.zeros([n_samples, 6]) + (th0 + [0, 0, 0])
    history = smc_run_drift_inference(key, hm, drift_model, transition_model, tqdm(data), init_states)
    cm = smc_history_reduce(history)
    locs_c = smc_apply_drift(drift_model, data, cm)

    fig, ax = plt.subplots(2, 2, figsize=(20, 6))
    for panel, column, title in [(ax[0, 0], 0, "x drift"), (ax[0, 1], 1, "y drift"), (ax[1, 0], 2, "z drift")]:
        panel.plot(cm[:, column])
        panel.set(title=title, xlabel="step")
    ax[1, 1].axis("off")  # the drift-only model has no scale parameter to plot
    plt.tight_layout()
    plt.show()

# FRC between two random halves of the final corrected localizations.
# x values are frc_x / ps * 1000 (presumably 1/um if frc_x is in 1/pixel and
# ps in nm -- TODO confirm).
par = np.random.rand(len(locs_c)) >= 0.5
v1 = hm.render(locs_c[par]).sum(axis=-1)
v2 = hm.render(locs_c[~par]).sum(axis=-1)

frc_x, frc_y = frc(v1, v2)
fig, ax = plt.subplots()
ax.plot(frc_x / ps * 1000, frc_y)
ax.axis([0, 10, -0.2, 1.1])
ax.set(title="FRC: drift-only model", xlabel="spatial frequency", ylabel="FRC");

In [None]:
# Persist the reduced SMC trajectory (np.save appends ".npy") and the
# drift-corrected coordinates next to the original columns.
np.save(f"{data_name}_n", cm)
for axis, column in enumerate(["xc", "yc", "zc"]):
    corrected[column] = locs_c[:, axis]
corrected.to_csv(f"{data_name}_n.csv", index=False)

## Model-2: Isotropic expansion

In [None]:
# Start this model from the raw, uncorrected coordinates as well.
locs = locs_c = all_data.loc[:, ["x", "y", "z"]].to_numpy()

# Per-axis `ps` for LMModel: ps for x and y, 4 * ps for z.
ps = 50
hm = LMModel(ps=(ps, ps, 4 * ps), norm_axis=(0, 1))

def transition_model(key, states, step):
    """SMC proposal for the drift + isotropic-expansion model.

    Particle state layout: [0:3] drift, [3:6] expansion centre, [6] log scale,
    [7:10] drift displacement over the previous step (velocity).

    On every ``bins_per_block``-th step (the first time bin of a new
    (block, slice) group, assuming ``step`` is the bin index h_idx), the drift
    takes a ``block_transition`` jump, centre and log scale are perturbed, and
    the velocity is reset. On other steps, the drift is extrapolated by the
    velocity when ``track_velocity`` is set, and drift, centre and log scale
    are all perturbed together.
    """
    # Bins per (block, slice) group = ceil(block_size / binning). This replaces
    # a hard-coded 40 (== 8000 // 200).
    bins_per_block = -(-block_size // binning)
    if step % bins_per_block == 0:
        k1, k2 = jax.random.split(key)
        next_state_a, next_id = block_transition(k1, states[:, :3])
        next_state_b, _ = gauss_transition_model(k2, states[:, 3:7], sigmas=[.1, .1, .1, sigma_s])
        next_state_c = jnp.zeros_like(next_state_a)
        next_state = jnp.c_[next_state_a, next_state_b, next_state_c]
        return next_state, next_id
    else:
        v = states[:, 7:10] if track_velocity else 0
        next_state_a = jnp.c_[states[:, :3] + v, states[:, 3:7]]
        # `sigma` is a list, so `+` concatenates the drift and centre/scale step sizes.
        next_state_a, next_id = gauss_transition_model(key, next_state_a, sigmas=sigma + [.1, .1, .1, sigma_s])
        next_state_b = next_state_a[:, :3] - states[:, :3]
        return jnp.c_[next_state_a, next_state_b], next_id

class CustomDrift(DriftModel):
    """Drift followed by an isotropic expansion about a movable centre.

    State layout (per particle, for 3-D locs): [0:3] drift, [3:6] expansion
    centre, [6] log scale. Later entries are not used here.
    """

    @classmethod
    def apply(cls, locs: ArrayLike, states:ArrayLike)->ArrayLike:
        _, ndim = locs.shape
        # The inserted axis lets each particle's parameters broadcast over all localizations.
        shift = states[..., None, :ndim]
        centre = states[..., None, ndim:ndim * 2]
        # A single scale factor shared by every axis.
        factor = jnp.exp(states[..., 6])[..., None, None]
        return (locs + shift - centre) * factor + centre

    @classmethod
    def compute_likelihoods(cls, model: LMModel, obs: ArrayLike, states: ArrayLike)->Array:
        locs, err, mask = obs
        log_ll = model.e_log_ll(cls.apply(locs, states), err, mask)
        # Extra 2 * log-scale per particle: presumably a log-Jacobian correction
        # for scaling the two axes in norm_axis=(0, 1) -- TODO confirm.
        return log_ll + 2 * states[..., 6]

drift_model = CustomDrift

# Iterative refinement: rebuild the density model from the current corrected
# coordinates (lateral errors scaled by m, z unchanged), re-run SMC from fresh
# initial states (same `key` every round), and apply the reduced trajectory.
for m in [2, 2, 2, 1, 1, 1]:
    hm.build(
        locs=locs_c,
        errs=errs * [m, m, 1],
        n_samples=500,
    )

    # State layout: drift (0:3), expansion centre (3:6), log scale (6), velocity (7:10).
    init_states = np.zeros([n_samples, 10])
    init_states[:, :3] = th0
    # Centre: scattered (sd 500) around the centroid of the current coordinates.
    init_states[:, 3:6] = np.random.normal(size=[n_samples, 3]) * 500 + locs_c.mean(axis=0)
    init_states[:, 6] = np.random.normal(size=[n_samples]) * 0.01 + s0

    history = smc_run_drift_inference(key, hm, drift_model, transition_model, tqdm(data), init_states)
    cm = smc_history_reduce(history)
    locs_c = smc_apply_drift(drift_model, data, cm)

    fig, ax = plt.subplots(2, 2, figsize=(20, 6))
    for panel, column, title in [(ax[0, 0], 0, "x drift"), (ax[0, 1], 1, "y drift"), (ax[1, 0], 2, "z drift")]:
        panel.plot(cm[:, column])
        panel.set(title=title, xlabel="step")
    ax[1, 1].plot(np.exp(cm[:, 6]))
    ax[1, 1].set(title="isotropic scale factor exp(state[6])", xlabel="step")
    plt.tight_layout()
    plt.show()

# FRC between two random halves of the final corrected localizations.
# x values are frc_x / ps * 1000 (presumably 1/um if frc_x is in 1/pixel and
# ps in nm -- TODO confirm).
par = np.random.rand(len(locs_c)) >= 0.5
v1 = hm.render(locs_c[par]).sum(axis=-1)
v2 = hm.render(locs_c[~par]).sum(axis=-1)

frc_x, frc_y = frc(v1, v2)
fig, ax = plt.subplots()
ax.plot(frc_x / ps * 1000, frc_y)
ax.axis([0, 10, -0.2, 1])
ax.set(title="FRC: isotropic-expansion model", xlabel="spatial frequency", ylabel="FRC");

In [None]:
# Persist the reduced SMC trajectory (np.save appends ".npy") and the
# corrected coordinates of the isotropic-expansion model.
np.save(f"{data_name}_i", cm)
for axis, column in enumerate(["xc", "yc", "zc"]):
    corrected[column] = locs_c[:, axis]
corrected.to_csv(f"{data_name}_i.csv", index=False)

## Model-3: Anisotropic expansion

In [None]:
# Again start from the raw, uncorrected coordinates.
locs = locs_c = all_data.loc[:, ["x", "y", "z"]].to_numpy()

# Per-axis `ps` for LMModel: ps for x and y, 4 * ps for z.
ps = 50
hm = LMModel(ps=(ps, ps, 4 * ps), norm_axis=(0, 1))

def transition_model(key, states, step):
    """SMC proposal for the drift + anisotropic-expansion model.

    Particle state layout: [0:3] drift, [3:6] expansion centre, [6] log xy
    scale, [7] log z scale, [8:11] drift displacement over the previous step
    (velocity).

    On every ``bins_per_block``-th step (the first time bin of a new
    (block, slice) group, assuming ``step`` is the bin index h_idx), the drift
    takes a ``block_transition`` jump, centre and log scales are perturbed, and
    the velocity is reset. On other steps, the drift is extrapolated by the
    velocity when ``track_velocity`` is set, and all parameters are perturbed
    together.
    """
    # Bins per (block, slice) group = ceil(block_size / binning). This replaces
    # a hard-coded 40 (== 8000 // 200).
    bins_per_block = -(-block_size // binning)
    if step % bins_per_block == 0:
        k1, k2 = jax.random.split(key)
        next_state_a, next_id = block_transition(k1, states[:, :3])
        next_state_b, _ = gauss_transition_model(k2, states[:, 3:8], sigmas=[.1, .1, .1, sigma_s, sigma_s])
        next_state_c = jnp.zeros_like(next_state_a)
        next_state = jnp.c_[next_state_a, next_state_b, next_state_c]
        return next_state, next_id
    else:
        v = states[:, 8:11] if track_velocity else 0
        next_state_a = jnp.c_[states[:, :3] + v, states[:, 3:8]]
        # `sigma` is a list, so `+` concatenates the drift and centre/scale step sizes.
        next_state_a, next_id = gauss_transition_model(key, next_state_a, sigmas=sigma + [.1, .1, .1, sigma_s, sigma_s])
        next_state_b = next_state_a[:, :3] - states[:, :3]
        return jnp.c_[next_state_a, next_state_b], next_id

class CustomDrift(DriftModel):
    """Drift followed by an anisotropic expansion about a movable centre.

    State layout (per particle, for 3-D locs): [0:3] drift, [3:6] expansion
    centre, [6] log scale shared by x and y, [7] log scale of z.
    """

    @classmethod
    def apply(cls, locs: ArrayLike, states:ArrayLike)->ArrayLike:
        _, ndim = locs.shape
        log_xy = states[..., 6]
        log_z = states[..., 7]
        # Per-axis factors (xy, xy, z) with an inserted localization axis for broadcasting.
        factors = jnp.exp(jnp.stack([log_xy, log_xy, log_z], axis=-1))[..., None, :]
        shift = states[..., None, :ndim]
        centre = states[..., None, ndim:ndim * 2]
        return (locs + shift - centre) * factors + centre

    @classmethod
    def compute_likelihoods(cls, model: LMModel, obs: ArrayLike, states: ArrayLike)->Array:
        locs, err, mask = obs
        log_ll = model.e_log_ll(cls.apply(locs, states), err, mask)
        # Extra 2 * log xy scale per particle (the z scale, state 7, gets no such
        # term): presumably a log-Jacobian correction for the two axes in
        # norm_axis=(0, 1) -- TODO confirm.
        return log_ll + 2 * states[..., 6]

drift_model = CustomDrift

# Iterative refinement: rebuild the density model from the current corrected
# coordinates (lateral errors scaled by m, z unchanged), re-run SMC from fresh
# initial states (same `key` every round), and apply the reduced trajectory.
for m in [2, 2, 2, 1, 1, 1]:
    hm.build(
        locs=locs_c,
        errs=errs * [m, m, 1],
        n_samples=500,
    )

    # State layout: drift (0:3), centre (3:6), log xy scale (6), log z scale (7), velocity (8:11).
    init_states = np.zeros([n_samples, 11])
    init_states[:, :3] = th0
    # Centre: scattered (sd 500) around the centroid of the current coordinates.
    init_states[:, 3:6] = np.random.normal(size=[n_samples, 3]) * 500 + locs_c.mean(axis=0)
    init_states[:, 6:8] = np.random.normal(size=[n_samples, 2]) * 0.01 + s0

    history = smc_run_drift_inference(key, hm, drift_model, transition_model, tqdm(data), init_states)
    cm = smc_history_reduce(history)
    locs_c = smc_apply_drift(drift_model, data, cm)

    fig, ax = plt.subplots(2, 2, figsize=(20, 6))
    for panel, column, title in [(ax[0, 0], 0, "x drift"), (ax[0, 1], 1, "y drift"), (ax[1, 0], 2, "z drift")]:
        panel.plot(cm[:, column])
        panel.set(title=title, xlabel="step")
    ax[1, 1].plot(np.exp(cm[:, 6:8]))
    ax[1, 1].set(title="scale factors exp(state[6:8])", xlabel="step")
    ax[1, 1].legend(["xy", "z"])
    plt.tight_layout()
    plt.show()

# FRC between two random halves of the final corrected localizations.
# x values are frc_x / ps * 1000 (presumably 1/um if frc_x is in 1/pixel and
# ps in nm -- TODO confirm).
par = np.random.rand(len(locs_c)) >= 0.5
v1 = hm.render(locs_c[par]).sum(axis=-1)
v2 = hm.render(locs_c[~par]).sum(axis=-1)

frc_x, frc_y = frc(v1, v2)
fig, ax = plt.subplots()
ax.plot(frc_x / ps * 1000, frc_y)
ax.axis([0, 10, -0.2, 1])
ax.set(title="FRC: anisotropic-expansion model", xlabel="spatial frequency", ylabel="FRC");

In [None]:
# Persist the reduced SMC trajectory (np.save appends ".npy") and the
# corrected coordinates of the anisotropic-expansion model.
np.save(f"{data_name}_a", cm)
for axis, column in enumerate(["xc", "yc", "zc"]):
    corrected[column] = locs_c[:, axis]
corrected.to_csv(f"{data_name}_a.csv", index=False)

## Comparison

In [None]:
# Compare the three models by the FRC of their saved corrected coordinates.
# Rendering uses `hm` and `ps` as last defined above (the anisotropic-model cell).
# x values are frc_x / ps * 1000 (presumably 1/um -- TODO confirm).
model_labels = {"n": "drift only", "i": "isotropic expansion", "a": "anisotropic expansion"}

fig, ax = plt.subplots()
for fn, label in model_labels.items():
    df = pd.read_csv(f"{data_name}_{fn}.csv")
    locs_c = df[["xc", "yc", "zc"]].to_numpy()
    par = np.random.rand(len(locs_c)) >= 0.5
    v1 = hm.render(locs_c[par]).sum(axis=-1)
    v2 = hm.render(locs_c[~par]).sum(axis=-1)

    frc_x, frc_y = frc(v1, v2)
    ax.plot(frc_x / ps * 1000, frc_y, label=label)

ax.axis([0, 10, -0.2, 1])
ax.set(title="FRC comparison of drift models", xlabel="spatial frequency", ylabel="FRC")
ax.legend();