In [16]:
import equinox as eqx
import jax.numpy as jnp
import jax

In [66]:
class BirdFlowMarkovChain(eqx.Module):
    cells:  list # the number of grid cells in each S&T week
    params: list # the model parameters (initial distribution, transition matrices)
    
    def __init__(self, key, cells):
        self.cells = cells
        self.params = []
        key, subkey = jax.random.split(key)
        self.params.append(jax.random.normal(subkey, shape=(cells[0]))) # initial distribution params
        for t in range(len(cells)-1):
            key, subkey = jax.random.split(key)
            self.params.append(jax.random.normal(subkey, shape=(cells[t+1], cells[t]))) # transition matrix params
    
    # compute single tstep / pairwise marginals
    # see appendix C of BirdFlow paper: https://www.biorxiv.org/content/10.1101/2022.04.12.488057v1.full.pdf+html
    def __call__(self):
        single_tstep_marginals = []
        pairwise_marginals = []
        mu_1 = jax.nn.softmax(self.params[0]) # the first single timestep marginal
        single_tstep_marginals.append(mu_1)
        softmax_mat = jax.vmap(jax.nn.softmax)
        
        mu_t = mu_1
        for t in range(1, len(self.params)):
            T = jax.nn.softmax(self.params[t]) # normalizes rows of the array
            pairwise_t = mu_t * T # probabilities of transitions from week t-1 -> week t
            mu_t = jnp.sum(pairwise_t, axis=1) # probabilities of bird's location in week t
            single_tstep_marginals.append(mu_t)
            pairwise_marginals.append(pairwise_t)
        
        return single_tstep_marginals, pairwise_marginals # convert to jax arrays

In [84]:
def l2_location_loss(single_tstep_marginals, st_marginals):
    total = 0
    T = len(single_tstep_marginals)
    for mu_t, mu_hat_t in zip(single_tstep_marginals, st_marginals):
        n_t = mu_t.shape[0]
        total += jnp.sum((mu_t - mu_hat_t) ** 2) / n_t   # intuition - average squared diff. between model / status and trends probabilities for week t
    return total / T   # return average of the average squared differences for each week (divide by T)

def w2_location_loss(single_tstep_marginals, st_marginals):
    pass

def distance_loss(pairwise_marginals, dists):
    return 0

def entropy_loss(single_tstep_marginals, pairwise_marginals):
    return 0

@eqx.filter_jit
def loss(model, st_marginals, alpha):
    single_tstep_marginals, pairwise_marginals = model()
    return alpha[0] * l2_location_loss(single_tstep_marginals, st_marginals) + \
           alpha[1] * distance_loss(pairwise_marginals, None) + \
           alpha[2] * entropy_loss(single_tstep_marginals, pairwise_marginals)
            

In [83]:
key = jax.random.PRNGKey(42)
key, subkey = jax.random.split(key)
model = BirdFlowMarkovChain(subkey, [2, 3])
key, subkey = jax.random.split(key)
st_marginals = [jnp.array([0.1, 0.9]), jnp.array([0.2, 0.2, 0.6])]

print(loss(model, st_marginals, [0.1, 0.1, 0.1]))

0.014581856


In [36]:
key = jax.random.PRNGKey(42)
key, subkey = jax.random.split(key)
softmax_mat = jax.vmap(jax.nn.softmax)
a = jax.random.normal(subkey, shape=(2, 2))
print(jax.jacobian(jax.nn.softmax)(a))

[[[[ 0.15359752 -0.15359752]
   [ 0.          0.        ]]

  [[-0.15359752  0.15359753]
   [ 0.          0.        ]]]


 [[[ 0.          0.        ]
   [ 0.07649832 -0.07649833]]

  [[ 0.          0.        ]
   [-0.07649833  0.07649827]]]]


In [40]:
a = jnp.array([1, 2])
b = jnp.array([[1, 2],[2, 3], [3, 4]])
c = a * b
print(c)
jnp.sum(c, axis=1)

[[1 4]
 [2 6]
 [3 8]]


Array([ 5,  8, 11], dtype=int32)