In [134]:
import os
import jraph
import haiku as hk
import jax
import jax.numpy as jnp
import jax.random as jr
import numpy as np 

# from data import make_graph
from utils import get_R_z_string, print_graph_attributes

tree = jax.tree_util

directed graph => senders == receivers?
-> all nodes that send also receive

edge_k receives from r_k and sends to s_k

In [135]:
prng_seq = hk.PRNGSequence(0)

data_dir = "/Users/Jed.Homer/phd/lfi/jaxdelfi/data/"
resolution = 1024 

redshifts = [0.]
R_values = ["5.0", "10.0", "15.0", "20.0"]

print(f"Running for redshifts:\n\t{redshifts}\nat resolution x={resolution} with R values:\n\t{R_values}.")

Rz_string = get_R_z_string(R_values, redshifts, n_moments_calculate=3)

alpha = jnp.array([
    0.3175, 0.049, 0.6711, 0.9624, 0.834])
parameters = jnp.load(
    os.path.join(data_dir, f"ALL_PDFS_PARAMS.npy")) 
fiducial_dv = jnp.load(
    os.path.join(data_dir, f'fiducial_moments_unflat_{Rz_string}.npy')).mean(axis=0)
simulations = jnp.load(
    os.path.join(data_dir, f"CALCULATED_PDF_MOMENTS_{Rz_string}.npy"))
# simulations = jnp.load(
#     os.path.join(data_dir, f"latin_pdfs_cut_{Rz_string}.npy"))
# fiducial_dv = jnp.load(
#     os.path.join(data_dir, f"fiducial_pdfs_cut_{Rz_string}.npy"))
fiducial_dv = fiducial_dv.mean(axis=0)[jnp.newaxis, :]

n_sims, parameter_dim = parameters.shape
data_dim = np.prod(fiducial_dv.shape)

parameters.shape, simulations.shape, fiducial_dv.shape

Running for redshifts:
	[0.0]
at resolution x=1024 with R values:
	['5.0', '10.0', '15.0', '20.0'].


((2000, 5), (2000, 1, 4, 3), (1, 4, 3))

In [136]:
def make_graph(data):
    n_node_per_graph = np.prod(data.shape)
    n_edge_per_graph = n_node_per_graph ** 2
    node_features = data.reshape(-1, 1)
    cardinality = jnp.array([n_node_per_graph, n_edge_per_graph])
    cardinality = cardinality.astype(jnp.float32)[jnp.newaxis, :]
    g = jraph.get_fully_connected_graph(
        node_features=node_features, 
        n_node_per_graph=n_node_per_graph,
        n_graph=1,
        global_features=cardinality,
        add_self_edges=False)
    return g

In [137]:
key = jr.PRNGKey(0)

graph = make_graph(fiducial_dv)

print_graph_attributes(graph)

nodes (12, 1)
edges None
senders (132,)
receivers (132,)
globals (1, 2)
n_node (1,)
n_edge (1,)


  num_node_features = jax.tree_leaves(node_features)[0].shape[0]
  if n_graph != jax.tree_leaves(global_features)[0].shape[0]:


In [143]:
def update_node_fn(nodes, s_attrs, r_attrs, global_attrs):
    """ 
        Update nodes using previous nodes and global attrs since edge
        features are None.
        -> take average of [nodes, global_attrs] vector
    """
    input_names = ["nodes", "s_attrs", "r_attrs", "global_attrs"]
    print("update node fn")
    inputs = []
    for n, _ in zip(input_names, [nodes, s_attrs, r_attrs, global_attrs]):
        if _ is not None:
            print("\t", n, _.shape)
            inputs.append(_)
        else:
            print("\t", n, None)
    h = jnp.concatenate(inputs, axis=1)
    print("h node update:", h.shape)
    return h.mean(axis=1)[:, None]

def update_global_fn(node_attrs, edge_attrs, globals_):
    """
        Update global attribute using node attrs and
        previous globals.
        -> dot product some weight, immitating a NN,
           with the concatenated node_attrs/globals_.
    """
    input_names = ["node_attrs", "edge_attrs", "globals_"]
    print("update global fn")
    inputs = []
    for n, _ in zip(input_names, [node_attrs, edge_attrs, globals_]):
        if _ is not None:
            print("\t", n, _.shape)
            inputs.append(_)
        else:
            print("\t", n, None)
    h = jnp.concatenate(inputs, axis=1)
    print("h global update:", h.shape)
    weight = jr.normal(key, (h.shape[-1], graph.globals.shape[-1]))
    return jnp.dot(h, weight)
    # return h.mean(axis=1)[:, None]

In [148]:
update_edge_fn = None

rho = jraph.segment_mean

aggregate_edges_for_nodes_fn = rho
aggregate_nodes_for_globals_fn = rho
aggregate_edges_for_globals_fn = rho 

In [155]:
def graph_network(graph):
    nodes, edges, receivers, senders, globals_, n_node, n_edge = graph

    # Equivalent to jnp.sum(n_node), but jittable
    sum_n_node = tree.tree_leaves(nodes)[0].shape[0]
    sum_n_edge = senders.shape[0]
    if not tree.tree_all(
        tree.tree_map(lambda n: n.shape[0] == sum_n_node, nodes)):
        raise ValueError(
            'All node arrays in nest must contain the same number of nodes.')

    # Nodes that connect to each node.
    # -> sent_attributes: get the nodes that send
    # -> received_attributes: get the nodes that receive 
    sent_attributes = tree.tree_map(lambda n: n[senders], nodes)
    received_attributes = tree.tree_map(lambda n: n[receivers], nodes)

    print("s attrs", sent_attributes.shape)
    print("r attrs", received_attributes.shape)
    # print("undirected?", jnp.allclose(sent_attributes, received_attributes)) 
    # print(sent_attributes)
    # print(received_attributes)

    # Here we scatter the global features to the corresponding edges,
    # giving us tensors of shape [num_edges, global_feat].
    global_edge_attributes = tree.tree_map(lambda g: jnp.repeat(
        g, n_edge, axis=0, total_repeat_length=sum_n_edge), globals_)

    if update_edge_fn:
        edges = update_edge_fn(
            edges, 
            sent_attributes, 
            received_attributes,
            global_edge_attributes)

    if update_node_fn:
        print("UPDATE NODE FN", 50 * "~")
        # 
        # No edge features so returns None!?
        sent_attributes = tree.tree_map(
            lambda e: aggregate_edges_for_nodes_fn(
                e, senders, sum_n_node), 
            edges)
        received_attributes = tree.tree_map(
            lambda e: aggregate_edges_for_nodes_fn(
                e, receivers, sum_n_node),
            edges)
        print("s/r attrs", sent_attributes, received_attributes)

        # Here we scatter the global features to the corresponding nodes,
        # giving us tensors of shape [num_nodes, global_feat].
        global_attributes = tree.tree_map(
            lambda g: jnp.repeat(
                g, n_node, axis=0, total_repeat_length=sum_n_node), 
            globals_)

        print("global attrs", global_attributes.shape if global_attributes is not None else global_attributes)

        nodes = update_node_fn(
            nodes, 
            sent_attributes,
            received_attributes, 
            global_attributes)
        print("nodes", nodes.shape)
        print(65 * "~")

    if update_global_fn:
        print("UPDATE GLOBAL FN", 48 * "~")

        n_graph = n_node.shape[0]
        graph_idx = jnp.arange(n_graph)
        print("n_graph", n_graph)
        print("n_node", n_node)
        print("graph_idx", graph_idx.shape)

        # To aggregate nodes and edges from each graph to global features,
        # we first construct tensors that map the node to the corresponding graph.
        # For example, if you have `n_node=[1,2]`, we construct the tensor
        # [0, 1, 1]. We then do the same for edges.
        node_gr_idx = jnp.repeat(
            graph_idx, n_node, axis=0, total_repeat_length=sum_n_node)
        edge_gr_idx = jnp.repeat(
            graph_idx, n_edge, axis=0, total_repeat_length=sum_n_edge)

        # Membership of nodes to graphs in graphs tuple
        print("node idx", node_gr_idx.shape)
        print("edge idx", edge_gr_idx.shape)

        # We use the aggregation function to pool the nodes/edges per graph.
        node_attributes = tree.tree_map(
            lambda n: aggregate_nodes_for_globals_fn(n, node_gr_idx, n_graph),
            nodes)
        edge_attribtutes = tree.tree_map(
            lambda e: aggregate_edges_for_globals_fn(e, edge_gr_idx, n_graph),
            edges)

        print("node attrs", node_attributes.shape)
        print("edge attrs", edge_attribtutes)

        # These pooled nodes are the inputs to the global update fn.
        globals_ = update_global_fn(
            node_attributes, edge_attribtutes, globals_)
        print("globals_", globals_.shape)
        print(65 * "~")

    graph_ = jraph.GraphsTuple(
        nodes=nodes,
        edges=edges,
        receivers=receivers,
        senders=senders,
        globals=globals_,
        n_node=n_node,
        n_edge=n_edge)
    return graph_

In [156]:
graph_ = graph_network(graph)

s attrs (132, 1)
r attrs (132, 1)
undirected? False
[[1.8996811e+01]
 [4.1674698e+02]
 [5.7894862e-01]
 [1.3066869e+00]
 [6.6244755e+00]
 [2.8116238e-01]
 [2.7146104e-01]
 [7.0658302e-01]
 [1.6624273e-01]
 [8.8361159e-02]
 [1.6433717e-01]
 [4.1674698e+02]
 [5.7894862e-01]
 [1.3066869e+00]
 [6.6244755e+00]
 [2.8116238e-01]
 [2.7146104e-01]
 [7.0658302e-01]
 [1.6624273e-01]
 [8.8361159e-02]
 [1.6433717e-01]
 [1.9013317e+00]
 [5.7894862e-01]
 [1.3066869e+00]
 [6.6244755e+00]
 [2.8116238e-01]
 [2.7146104e-01]
 [7.0658302e-01]
 [1.6624273e-01]
 [8.8361159e-02]
 [1.6433717e-01]
 [1.9013317e+00]
 [1.8996811e+01]
 [1.3066869e+00]
 [6.6244755e+00]
 [2.8116238e-01]
 [2.7146104e-01]
 [7.0658302e-01]
 [1.6624273e-01]
 [8.8361159e-02]
 [1.6433717e-01]
 [1.9013317e+00]
 [1.8996811e+01]
 [4.1674698e+02]
 [6.6244755e+00]
 [2.8116238e-01]
 [2.7146104e-01]
 [7.0658302e-01]
 [1.6624273e-01]
 [8.8361159e-02]
 [1.6433717e-01]
 [1.9013317e+00]
 [1.8996811e+01]
 [4.1674698e+02]
 [5.7894862e-01]
 [2.8116238e-

In [151]:
print("\nbefore")
print_graph_attributes(graph)
print("\nafter")
print_graph_attributes(graph_)


before
nodes (12, 1)
edges None
senders (132,)
receivers (132,)
globals (1, 2)
n_node (1,)
n_edge (1,)

after
nodes (12, 1)
edges None
senders (132,)
receivers (132,)
globals (1, 2)
n_node (1,)
n_edge (1,)


In [152]:
graph_ = graph
for n in range(3):
    graph_ = graph_network(graph_)
    print(graph_.globals)

s attrs (132, 1)
r attrs (132, 1)
UPDATE NODE FN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
s/r attrs None None
global attrs (12, 2)
update node fn
	 nodes (12, 1)
	 s_attrs None
	 r_attrs None
	 global_attrs (12, 2)
h node update: (12, 3)
nodes (12, 1)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
UPDATE GLOBAL FN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
n_graph 1
n_node [12]
graph_idx (1,)
node idx (12,)
edge idx (132,)
node attrs (1, 1)
edge attrs None
update global fn
	 node_attrs (1, 1)
	 edge_attrs None
	 globals_ (1, 2)
h global update: (1, 3)
globals_ (1, 2)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
[[ 55.101326 -84.6224  ]]
s attrs (132, 1)
r attrs (132, 1)
UPDATE NODE FN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
s/r attrs None None
global attrs (12, 2)
update node fn
	 nodes (12, 1)
	 s_attrs None
	 r_attrs None
	 global_attrs (12, 2)
h node update: (12, 3)
nodes (12, 1)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~