In [2]:
import os
import sys
from functools import partial

import h5py
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import optax
from jax import jit
from ott.geometry.costs import CostFn
from ott.geometry.geometry import Geometry
from ott.geometry.pointcloud import PointCloud

# The project sources are imported from a relative checkout, so the path must
# be registered before any of the project modules below are imported.
sys.path.append(os.path.abspath("birdflow/birdflow-bilevel/src/"))
import flow_model_training
from flow_model import model_forward
from flow_model_training import loss_fn, mask_input, Datatuple, train_model, w2_loss_fn
from hdfs import get_plot_parameters

In [3]:
hdf_src = 'birdflow/birdflow-bilevel/ebird-data-loading/amewoo_2021_39km.hdf5'

# Exponent applied to the raw distances before normalisation.
dist_pow = 0.1

# This cell only reads from the file, so open it read-only and let the context
# manager close the handle. Both datasets are copied into NumPy arrays inside
# the block, so nothing below needs the file to stay open.
# NOTE: `file` is a closed handle after this block.
with h5py.File(hdf_src, 'r') as file:
    # Transposed so that the first axis is the one used as the week count below.
    true_densities = np.asarray(file['distr']).T
    distance_vector = np.asarray(file['distances'])**dist_pow

weeks = true_densities.shape[0]
total_cells = true_densities.shape[1]

distance_vector *= 1 / (100**dist_pow) # normalize the distance vector

# Called after the handle above is released; this helper is given the path,
# not the open handle.
ncol, nrow, dynamic_masks, big_mask = get_plot_parameters(hdf_src)

dtuple = Datatuple(weeks, ncol, nrow, total_cells, distance_vector, dynamic_masks, big_mask)

# Sanity check: total mass of one row of the density table (row index 5).
print(jnp.sum(jnp.asarray(true_densities[5, :])))

distance_matrices, distance_matrices_for_week, masked_densities = mask_input(true_densities, dtuple)

# First-axis length of each masked density array, one entry per array.
cells = [d.shape[0] for d in masked_densities]
print(cells)
# Get the random seed and optimizer
key = hk.PRNGSequence(42)
optimizer = optax.adam(1e-3)

# Instantiate loss function
obs_weight = 1
dist_weight = 0.5
ent_weight = 0.5

# Keyword arguments shared by both loss functions.
loss_kwargs = dict(cells=cells,
                   true_densities=masked_densities,
                   d_matrices=distance_matrices,
                   obs_weight=obs_weight,
                   dist_weight=dist_weight,
                   ent_weight=ent_weight)

# Bind from the module-qualified originals rather than from the names being
# assigned. The notebook-level names `loss_fn` / `w2_loss_fn` are rebound here,
# so partial(loss_fn, ...) would, on a second run of this cell, wrap the
# already-jitted partial and pass the bound keywords through jit as traced
# arguments. Referencing flow_model_training.<name> keeps the cell re-runnable.
loss_fn = jit(partial(flow_model_training.loss_fn, **loss_kwargs))

# The W2 variant additionally takes the per-week distance matrices and is left
# un-jitted, as before.
w2_loss_fn = partial(flow_model_training.w2_loss_fn,
                     d_matrices_for_week=distance_matrices_for_week,
                     **loss_kwargs)


# Run Training and get params and losses
# training_steps = 10
# params, loss_dict = train_model(loss_fn,
#                                 optimizer,
#                                 training_steps,
#                                 cells,
#                                 dtuple.weeks,
#                                 key)

1.0
[1287, 1327, 1383, 1472, 1616, 1735, 1921, 2058, 2031, 2112, 2273, 2441, 2675, 2812, 2792, 2788, 2827, 2768, 2750, 2671, 2568, 2402, 2267, 2240, 1969, 2064, 2303, 2282, 2037, 2095, 2015, 2028, 1916, 1887, 1955, 1944, 2163, 2337, 2264, 2410, 2617, 2924, 3128, 2989, 2729, 2105, 1764, 1437, 1401, 1297, 1305, 1306, 1287]


In [None]:
# Restart the key sequence from the fixed seed so the initial parameters are
# the same every time this cell runs.
key = hk.PRNGSequence(42)
init_rng = next(key)
params = model_forward.init(init_rng, cells, weeks)

# Forward pass with the freshly initialised parameters; no RNG key is supplied
# to apply (second argument is None).
pred = model_forward.apply(params, None, cells, weeks)

# Evaluation of the standard loss is currently disabled:
# standard_loss_val = loss_fn(params)
# print(standard_loss_val)

# Evaluate the W2 loss at the initial parameters and show it.
w2_loss_val = w2_loss_fn(params)
print(w2_loss_val)

computed w2 loss for week 1
computed w2 loss for week 2
computed w2 loss for week 3
computed w2 loss for week 4
computed w2 loss for week 5
computed w2 loss for week 6
computed w2 loss for week 7
computed w2 loss for week 8
computed w2 loss for week 9
computed w2 loss for week 10
computed w2 loss for week 11
computed w2 loss for week 12
computed w2 loss for week 13
computed w2 loss for week 14
computed w2 loss for week 15
computed w2 loss for week 16
computed w2 loss for week 17
computed w2 loss for week 18
computed w2 loss for week 19
computed w2 loss for week 20
computed w2 loss for week 21
computed w2 loss for week 22
computed w2 loss for week 23
computed w2 loss for week 24


In [None]:
# Five evenly spaced values from 0 to 10, turned into a column so that every
# row is a single one-dimensional point.
a = jnp.linspace(0, 10, 5)
x = a[:, None].astype(float)

# Point cloud built with the same points on both sides.
pc = PointCloud(x, x)

class CustomCostFn(CostFn):
    """Cost function that always reports a fixed 10x10 cost matrix of tens.

    The matrix is created once in ``__init__``; ``all_pairs`` returns it
    regardless of the points it is given. ``__call__`` is a stub.
    """

    def __init__(self):
        super().__init__()
        # Constant cost of 10 between every pair of the 10 x 10 entries.
        self.cost_matrix = jnp.ones((10, 10)) * 10

    def __call__(self, x, y):
        """Placeholder for the point-to-point cost; does nothing and returns None."""
        return None

    def all_pairs(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
        """Return the stored cost matrix; ``x`` and ``y`` are not used.

        See https://ott-jax.readthedocs.io/en/latest/_modules/ott/geometry/costs.html#CostFn.all_pairs
        """
        return self.cost_matrix
# Build a Geometry directly from an explicit cost matrix (every entry is 10)
# and print the matrix the geometry reports back.
cost_shape = (10, 10)
custom_cost_mat = jnp.ones(cost_shape) * 10
geom = Geometry(cost_matrix=custom_cost_mat)
print(geom.cost_matrix)

[[10. 10. 10. 10. 10. 10. 10. 10. 10. 10.]
 [10. 10. 10. 10. 10. 10. 10. 10. 10. 10.]
 [10. 10. 10. 10. 10. 10. 10. 10. 10. 10.]
 [10. 10. 10. 10. 10. 10. 10. 10. 10. 10.]
 [10. 10. 10. 10. 10. 10. 10. 10. 10. 10.]
 [10. 10. 10. 10. 10. 10. 10. 10. 10. 10.]
 [10. 10. 10. 10. 10. 10. 10. 10. 10. 10.]
 [10. 10. 10. 10. 10. 10. 10. 10. 10. 10.]
 [10. 10. 10. 10. 10. 10. 10. 10. 10. 10.]
 [10. 10. 10. 10. 10. 10. 10. 10. 10. 10.]]


In [None]:
arr = jnp.array([[1, 2], [3, 4]])


def fn(x, y):
    """Return ``arr[x, y]`` after printing the shape of each index argument."""
    # Under jax.vmap the shapes printed here are those seen per example,
    # i.e. with the mapped leading axis removed.
    for index in (x, y):
        print(index.shape)
    return arr[x, y]  # Direct indexing works inside vmap


vmapped_fn = jax.vmap(fn)

# Batch inputs (kept for experimentation; the call below does not use them)
x_batch = jnp.array([0, 1])  # Row indices
y_batch = jnp.array([1, 0])  # Column indices

# Both index arrays have shape (1, 1): a batch of one example whose
# per-example index has shape (1,).
row_index = jnp.array([[0]])
col_index = jnp.array([[1]])
result = vmapped_fn(row_index, col_index)
print(result)

(1,)
(1,)
[[2]]


In [None]:
# Show the kernel's working directory; the relative data and source paths used
# in this notebook are resolved against it.
working_dir = os.getcwd()
print(working_dir)

/Users/jacobepstein/Documents/work
