In [1]:
# Goals (6/30):
# - understand the specifics of the algorithms for computing single-timestep and pairwise marginals from a Markov chain
# - go through the code closely, figure out how the calculation of these marginals is implemented in model_forward (using haiku)



In [18]:
# important haiku functions

"""
hk.transform(f) => a pair of pure functions: init(rng, *args) => params and apply(params, rng, *args) => output

- see https://dm-haiku.readthedocs.io/en/latest/api.html#haiku.transform

hk.get_parameter
- creates (initializes) the named parameter when the transformed function is run through init, and
  looks up the existing value from params when it is run through apply
"""

'\nhk.transform(f(args)) => (init(args) => params) and (apply(params, args) => output) functions\n\n- see https://dm-haiku.readthedocs.io/en/latest/api.html#haiku.transform\n\n- \n'

In [52]:
# the model
import haiku as hk

from jax.nn import softmax
import jax.numpy as jnp
from jax.random import categorical


class InitialLoc(hk.Module):
    """Learnable probability distribution over the cells of the first timestep."""

    def __init__(self, cells):
        """Remember the number of cells; the Haiku module is named 'Initial_Params'."""
        super().__init__(name='Initial_Params')
        self.cells = cells

    def __call__(self):
        """Return softmax of the unconstrained parameter vector 'z0' (length `cells`)."""
        # Unconstrained logits; the softmax below maps them onto the simplex.
        logits = hk.get_parameter(
            'z0',
            (self.cells,),
            dtype='float32',
            init=hk.initializers.RandomNormal(),
        )
        return softmax(logits)


class FlowBlock(hk.Module):
    """One transition step between two consecutive timesteps.

    Holds an unconstrained (cells1, cells2) parameter 'z'. A row-wise softmax
    turns it into transition probabilities, which are then scaled by the
    incoming distribution to give the flow (joint mass) between the two steps.
    """

    def __init__(self, cells1, cells2, week_num=None):
        """Create the block.

        Args:
            cells1: number of cells at the source timestep (rows of 'z').
            cells2: number of cells at the destination timestep (columns of 'z').
            week_num: optional index used to name the module 'Week_{week_num}';
                when omitted the module is named 'transition_block'.
        """
        # Compare against None explicitly: a plain truthiness test would treat
        # week_num=0 as "not given" and silently fall back to the generic name.
        if week_num is None:
            name = 'transition_block'
        else:
            name = f'Week_{week_num}'
        super().__init__(name=name)
        self.cells1 = cells1
        self.cells2 = cells2

    def __call__(self, last_week):
        """Return the flow matrix for this step.

        Args:
            last_week: distribution over the source cells; it is reshaped to a
                column so it scales each row of the transition matrix
                (assumes it has `cells1` entries -- TODO confirm at call sites).

        Returns:
            Array of shape (cells1, cells2): row i is last_week[i] times the
            softmax of row i of 'z'.
        """
        z = hk.get_parameter(
            'z',
            (self.cells1, self.cells2),
            #init=jnp.zeros,
            init=hk.initializers.RandomNormal(),
            dtype='float32'
        )

        # Softmax over axis=1 makes every row a distribution over destinations.
        trans_prop = softmax(z, axis=1)
        flow = trans_prop * last_week.reshape(-1, 1)
        return flow
    

class FlowModel(hk.Module):
    """Markov-chain model: an initial distribution followed by per-week FlowBlocks."""

    def __init__(self, cells, num_weeks, name='Flow_Model'):
        """Store the per-week cell counts (`cells`) and the number of weeks."""
        super().__init__(name=name)
        self.cells = cells
        self.num_weeks = num_weeks

    def __call__(self):
        """Run the chain forward.

        Returns:
            (d0, flow_amounts): the initial distribution and a list with one
            flow matrix per transition (num_weeks - 1 entries).
        """
        d0 = InitialLoc(self.cells[0])()

        flow_amounts = []
        current = d0
        # Blocks are numbered from 1, so week_num doubles as the destination index.
        for week_num in range(1, self.num_weeks):
            block = FlowBlock(self.cells[week_num - 1], self.cells[week_num], week_num=week_num)
            flow_amounts.append(block(current))
            # Summing the flow over source cells gives the next week's distribution.
            current = flow_amounts[-1].sum(axis=0)
        return (d0, flow_amounts)

def predict(cells, weeks):
    """Build a FlowModel for `cells`/`weeks` and return its (d0, flow_amounts) output."""
    return FlowModel(cells, weeks)()

# Wrap predict with Haiku so it can be used through model_forward.init / model_forward.apply.
model_forward = hk.transform(predict)

In [53]:
# Fixed seed so the printed parameters are reproducible.
key = hk.PRNGSequence(17)
# Toy problem: 5 weeks with 2 cells each (arguments are predict's cells, weeks).
params = model_forward.init(next(key), [2]*5, 5)
print(params)

{'Flow_Model/Initial_Params': {'z0': Array([-0.5936502 , -0.03086448], dtype=float32)}, 'Flow_Model/Week_1': {'z': Array([[-0.3210647 , -0.42314035],
       [-0.40206048,  0.05655674]], dtype=float32)}, 'Flow_Model/Week_2': {'z': Array([[ 0.28132653,  0.35250533],
       [-0.60869765,  0.18230876]], dtype=float32)}, 'Flow_Model/Week_3': {'z': Array([[-0.49886376, -4.3779397 ],
       [ 1.6097612 , -1.3355856 ]], dtype=float32)}, 'Flow_Model/Week_4': {'z': Array([[ 0.4312313 , -0.9348354 ],
       [-0.92259514,  0.61290187]], dtype=float32)}}


In [41]:
# No rng is passed to apply (second argument is None); the output is (d0, flow_amounts).
initial, pairwise_marginals = model_forward.apply(params, None, [2]*5, 5)
# Summing a flow over axis 0 (source cells) is the same reduction FlowModel uses
# to obtain the distribution over the destination cells.
print(pairwise_marginals[2].sum(axis=0))

[0.96152234 0.03847763]


In [93]:
# mixture of products model
import haiku as hk
from jax.nn import softmax
import jax.numpy as jnp
from jax.random import categorical

class Product(hk.Module):
    """One mixture component: an independent distribution for each timestep."""

    def __init__(self, cells, idx):
        """Store the per-timestep cell counts; the module is named 'Product{idx}'."""
        super().__init__(name=f"Product{idx}")
        self.cells = cells

    def __call__(self, t):
        """Return this component's distribution for timestep `t`.

        The unconstrained parameter 'week_{t}' has cells[t] entries and is
        passed through a softmax.
        """
        logits = hk.get_parameter(
            f'week_{t}',
            (self.cells[t],),
            dtype='float32',
            init=hk.initializers.RandomNormal(),
        )
        return softmax(logits, axis=0)


class MixtureOfProductsModel(hk.Module):
    """Mixture of `n` Product components with learnable mixing weights."""

    def __init__(self, cells, weeks, n, name="MixtureOfProductsModel"):
        """Store the per-timestep cell counts, number of weeks and number of components."""
        super().__init__(name=name)
        self.cells = cells
        self.weeks = weeks
        self.n = n

    def get_marginal(self, products, weights, tsteps):
        """Return the mixture's joint marginal over the timesteps in `tsteps`.

        For every component the per-timestep distributions are combined with an
        outer product (tensordot with axes=0), so the result has one axis per
        entry of `tsteps`; the components are then summed with their weights.
        """
        total = 0
        for component in range(self.n):
            product = products[component]
            # Scalar 1 is the identity for the outer product.
            joint = jnp.asarray(1)
            for t in tsteps:
                joint = jnp.tensordot(joint, product(t), axes=0)
            total = total + weights[component] * joint
        return total

    def __call__(self):
        """Return (single-timestep marginals, pairwise marginals for consecutive timesteps)."""
        # Mixing weights: unconstrained logits pushed through a softmax.
        mixing_logits = hk.get_parameter(
            'weights',
            (self.n,),
            dtype='float32',
            init=hk.initializers.RandomNormal(),
        )
        weights = softmax(mixing_logits, axis=0)

        # One Product module per mixture component.
        products = [Product(self.cells, k) for k in range(self.n)]

        single_tstep_marginals = []
        for t in range(self.weeks):
            single_tstep_marginals.append(self.get_marginal(products, weights, [t]))

        pairwise_marginals = []
        for t in range(self.weeks - 1):
            pairwise_marginals.append(self.get_marginal(products, weights, [t, t + 1]))

        return single_tstep_marginals, pairwise_marginals

def predict(cells, weeks, n):
    """Build a MixtureOfProductsModel with `n` components and return its output."""
    return MixtureOfProductsModel(cells, weeks, n)()


# NOTE: rebinds model_forward to the mixture-of-products predict; the earlier
# FlowModel-based model_forward is shadowed from here on.
model_forward = hk.transform(predict)

In [94]:
# loss function for mixture of products
# Number of mixture components; passed to model_forward as predict's `n` argument.
N_PRODUCTS = 10
def loss_fn(params, cells, true_densities, d_matrices, obs_weight, dist_weight, ent_weight):
    """Weighted loss of the mixture-of-products model.

    Runs model_forward with N_PRODUCTS components and combines three terms:
    observation error, distance cost, and (negated) entropy.

    Returns:
        (total, (obs, dist, ent)) -- the weighted total and the raw terms.
    """
    num_weeks = len(true_densities)
    pred_densities, flows = model_forward.apply(params, None, cells, num_weeks, N_PRODUCTS)

    obs = obs_loss(pred_densities, true_densities)
    dist = distance_loss(flows, d_matrices)
    # ent_loss adds the entropies of its first argument and subtracts those of
    # its second, so this is entropy(pairwise) minus entropy(single-timestep).
    ent = ent_loss(flows, pred_densities)

    components = (obs, dist, ent)
    total = (obs_weight * obs) + (dist_weight * dist) + (-1 * ent_weight * ent)
    return total, components

def obs_loss(pred_densities, true_densities):
    """Sum of squared differences between predicted and true densities.

    The two sequences are paired with zip, so extra entries in the longer one
    are ignored; with no pairs the result is the integer 0.
    """
    return sum(
        jnp.sum(jnp.square(true - pred))
        for pred, true in zip(pred_densities, true_densities)
    )

def distance_loss(flows, d_matrices):
    """Total of the elementwise products flow * d_matrix, summed over all pairs.

    Flows and matrices are paired with zip; with no pairs the result is the
    integer 0.
    """
    return sum(
        jnp.sum(flow * d_matrix)
        for flow, d_matrix in zip(flows, d_matrices)
    )

def entropy(probs):
    """Shannon entropy (natural log) of the array `probs`: -sum(p * log(p)).

    Entries that are exactly zero contribute 0, following the usual
    0 * log(0) = 0 convention. Evaluating log(0) directly would give
    0 * -inf = NaN and poison the whole sum (and its gradient).

    Args:
        probs: array of probabilities (any shape); all entries are summed.

    Returns:
        Scalar array holding the entropy.
    """
    # Swap zeros for 1 *before* the log so neither the forward value nor the
    # gradient ever touches log(0); for those entries p * log(1) = 0.
    safe_probs = jnp.where(probs == 0, 1.0, probs)
    plogp = probs * jnp.log(safe_probs)
    h = -1 * jnp.sum(plogp)
    return h

def ent_loss(probs, flows):
    """Sum of entropies of `probs` minus sum of entropies of `flows`.

    NOTE(review): the parameter names are misleading. In this notebook loss_fn
    calls ent_loss(flows, pred_densities), so `probs` receives the pairwise
    arrays and `flows` receives the single-timestep ones.
    """
    # Entropies of the first argument are added ...
    total = sum(entropy(added) for added in probs)
    # ... and those of the second argument are subtracted one by one.
    for subtracted in flows:
        total = total - entropy(subtracted)
    return total



In [103]:
# see if we can compute the loss function! (we can!!)

import os
import h5py
import numpy as np
import sys
sys.path.insert(1, '/Users/jacobepstein/Documents/work/BirdFlowPy/')
from flow_model_training import Datatuple, mask_input

# load true_densities from ebird_st
obs_weight = 1
ent_weight = 1e-4
dist_weight = 1e-2
dist_pow = 0.4
hdf_src = os.path.join("/Users/jacobepstein/Documents/work/BirdFlowModels", "amewoo_2021_48km.hdf5")

# Read everything needed inside a context manager so the HDF5 handle is closed
# as soon as the arrays are in memory (np.asarray copies the datasets out).
# `file` stays bound afterwards but refers to a closed file.
with h5py.File(hdf_src, 'r') as file:
    true_densities = np.asarray(file['distr']).T
    distance_vector = np.asarray(file['distances'])**dist_pow
    masks = np.asarray(file['geom']['dynamic_mask']).T.astype(bool)

# get weeks / total cells
weeks = true_densities.shape[0]
total_cells = true_densities.shape[1]

# compute distance matrices, cells array
distance_vector *= 1 / (100**dist_pow)
dtuple = Datatuple(weeks, total_cells, distance_vector, masks)
distance_matrices, masked_densities = mask_input(true_densities, dtuple)
# number of cells per week, taken from the masked densities
cells = [d.shape[0] for d in masked_densities]

# initialize model
key = hk.PRNGSequence(17)
params = model_forward.init(next(key), cells, weeks, N_PRODUCTS)
# last expression of the cell: displays (total, (obs, dist, ent))
loss_fn(params, cells, masked_densities, distance_matrices, obs_weight, dist_weight, ent_weight)

(Array(1.2944885, dtype=float32),
 (Array(0.07623632, dtype=float32),
  Array(125.34545, dtype=float32),
  Array(352.02237, dtype=float32)))

In [101]:
# Sanity check: every predicted single-timestep marginal should have as many
# entries as the corresponding week has cells.
pred_densities, flows = model_forward.apply(params, None, cells, weeks, N_PRODUCTS)
print(cells)
print([len(arr) for arr in pred_densities])

[675, 702, 749, 784, 879, 954, 1055, 1114, 1133, 1170, 1313, 1438, 1566, 1645, 1631, 1616, 1625, 1602, 1581, 1537, 1474, 1360, 1276, 1269, 1095, 1151, 1307, 1279, 1127, 1156, 1103, 1097, 1022, 1012, 1034, 1042, 1174, 1297, 1241, 1303, 1418, 1608, 1735, 1661, 1438, 1063, 907, 765, 754, 696, 689, 693, 675]
[675, 702, 749, 784, 879, 954, 1055, 1114, 1133, 1170, 1313, 1438, 1566, 1645, 1631, 1616, 1625, 1602, 1581, 1537, 1474, 1360, 1276, 1269, 1095, 1151, 1307, 1279, 1127, 1156, 1103, 1097, 1022, 1012, 1034, 1042, 1174, 1297, 1241, 1303, 1418, 1608, 1735, 1661, 1438, 1063, 907, 765, 754, 696, 689, 693, 675]
