In [4]:
import h5py
import numpy as np
from mixture_of_products_model_training import Datatuple, mask_input
import optax
import haiku as hk
hdf_src = '/Users/jacobepstein/Documents/work/BirdFlowModels/amewoo_2021_48km.hdf5'
file = h5py.File(hdf_src, 'r')

true_densities = np.asarray(file['distr']).T

weeks = true_densities.shape[0]
total_cells = true_densities.shape[1]

distance_vector = np.asarray(file['distances'])**0.4
distance_vector *= 1 / (100**0.4)
masks = np.asarray(file['geom']['dynamic_mask']).T.astype(bool)

dtuple = Datatuple(weeks, total_cells, distance_vector, masks)
distance_matrices, masked_densities = mask_input(true_densities, dtuple)
cells = [d.shape[0] for d in masked_densities]

# Get the random seed and optimizer
key = hk.PRNGSequence(17)
optimizer = optax.adam(0.1)
max(cells)

1735

In [78]:
import haiku as hk
from jax.nn import softmax
import jax.numpy as jnp
from jax.random import categorical
import jax
import numpy as np

class Product(hk.Module):
    def __init__(self, cells, idx):
        super().__init__()
        self.cells = cells
    
    def __call__(self, t):
        weekly_marginal = hk.get_parameter(
            f'week_{t}',
            (self.cells[t],),
            init=hk.initializers.RandomNormal(),
            dtype='float32'
        )

        return softmax(weekly_marginal, axis=0)


class MixtureOfProductsModel(hk.Module):
    def __init__(self, cells, weeks, n, name="MixtureOfProductsModel", learn_weights=True):
        super().__init__(name=name)
        self.weeks = weeks
        self.cells = cells
        self.n = n # number of product distributions
        self.products = []
        self.learn_weights = learn_weights
    
    def get_prod_k_marginal(self, k, components, tsteps):
        prod_k_marginal = jnp.asarray(1)
        for tstep in tsteps:
            prod_k_marginal = jnp.tensordot(prod_k_marginal, components[tstep][k], axes=0) # indexing with k should be ok now?
        return prod_k_marginal
    
    def get_marginal(self, weights, components, tsteps):
        vectorized_get_prod_k_marginal = hk.vmap(self.get_prod_k_marginal, split_rng=False, in_axes=(1, None, None))
        ks = jnp.array([jnp.arange(self.n)])
        marginals = vectorized_get_prod_k_marginal(ks, components, tsteps) * jnp.array([weights]).T
        return marginals.sum(axis=0)

    def __call__(self):
        if self.learn_weights:
            # initialize weights
            weights = hk.get_parameter(
                'weights',
                (self.n,),
                init=hk.initializers.RandomNormal(),
                dtype='float32'
            )
        else:
            # fix all weights to be equal
            weights = jnp.zeros(self.n)
        weights = softmax(weights, axis=0)
        
        # idea: list of T jnp.arrays of dimension n x cells[t]
        # compute weekly / pairwise marginals from this list
        components = [softmax(hk.get_parameter(f'week_{t}', (self.n, self.cells[t]), init=hk.initializers.RandomNormal(), dtype='float32')) for t in range(self.weeks)]
        
        # TODO: see if we can vmap this as well? (don't think we can)
        single_tstep_marginals = [self.get_marginal(weights, components, [t]) for t in range(self.weeks)]
        pairwise_marginals = [self.get_marginal(weights, components, [t, t+1]) for t in range(self.weeks-1)]

        return single_tstep_marginals, pairwise_marginals

def predict(cells, weeks, n, learn_weights=True):
    model = MixtureOfProductsModel(cells, weeks, n, learn_weights=learn_weights)
    return model()


model_forward = hk.transform(predict)

In [79]:
params = model_forward.init(next(key), [10, 10, 10], 3, 10, learn_weights=True)

In [80]:
params

{'MixtureOfProductsModel': {'weights': Array([-0.5101634 , -0.2055041 ,  0.17628728, -1.2549077 , -1.0884277 ,
          0.7180746 , -1.157557  ,  0.13669363, -0.5092141 ,  1.7665744 ],      dtype=float32),
  'week_0': Array([[-0.7736178 ,  1.4341968 ,  0.0066532 , -1.1540395 , -0.01417581,
           2.0033388 , -0.95001525,  0.11421081, -1.5642594 , -0.732144  ],
         [ 1.1871496 ,  0.00672851,  1.3428135 ,  0.6911836 , -2.2489347 ,
           0.16833398, -0.2334649 , -1.4876806 ,  0.14990021, -2.2604008 ],
         [ 0.36821055,  0.64264745, -0.45789617,  0.5150696 , -1.2723025 ,
          -0.5545435 ,  0.79329073, -1.3114611 ,  0.09347876,  0.6121092 ],
         [ 0.41948295,  0.5762406 , -1.0716459 ,  0.31288287,  0.9346875 ,
          -1.5149761 ,  0.28801233, -0.9442179 , -0.06918854, -1.084417  ],
         [-0.7518748 ,  1.5847896 ,  1.8856518 ,  0.68798393, -3.4653306 ,
           0.47742015,  1.2173004 , -0.8711961 , -0.28522038,  0.9569032 ],
         [ 0.11116216,  1.10

In [81]:
weekly, pairwise = model_forward.apply(params, None, [10, 10, 10], 3, 10)
weekly, pairwise

([Array([[0.05142162, 0.06615908, 0.0697176 , 0.04211472, 0.0344048 ,
          0.05523689, 0.04466748, 0.02726606, 0.02592766, 0.038199  ],
         [0.0697361 , 0.08972248, 0.0945484 , 0.05711442, 0.04665851,
          0.0749102 , 0.06057639, 0.03697721, 0.03516214, 0.05180406],
         [0.10215686, 0.13143507, 0.13850458, 0.08366728, 0.06835034,
          0.10973644, 0.08873873, 0.05416815, 0.05150924, 0.0758881 ],
         [0.02441785, 0.03141602, 0.0331058 , 0.01999842, 0.01633732,
          0.02622955, 0.02121061, 0.01294744, 0.0123119 , 0.01813901],
         [0.0288409 , 0.03710672, 0.03910258, 0.02362093, 0.01929666,
          0.03098077, 0.0250527 , 0.01529274, 0.01454208, 0.02142471],
         [0.17561549, 0.22594695, 0.23809998, 0.14383046, 0.11749949,
          0.18864536, 0.15254867, 0.0931192 , 0.08854831, 0.13045746],
         [0.0269145 , 0.03462821, 0.03649076, 0.02204319, 0.01800775,
          0.02891144, 0.02337933, 0.01427127, 0.01357075, 0.01999367],
         [0.0