# Testing jraph

In [1]:
from preprocessing_jraph import get_init_crystal_states
# import ase.db
import jax
import typing
# from preprocessing import get_cutoff_mask, get_init_charges, get_gaussian_distance_encodings, v_center_at_atoms_diagonal, type_to_charges_dict, SYMBOL_MAP
from jax import lax, random, numpy as jnp
import optax
import jraph
from typing import Any, Callable, Sequence, Optional, Tuple


key, subkey = random.split(random.PRNGKey(0))
h_dim = 126
e_dim = 48
layers = [32,32,1] # hidden layers
T = 3
path = "data/SrTiO3_500.db"
n_elems = 3
R_SWITCH = 0.5
R_CUT = 3.0


descriptors, distances, distances_encoded, init_charges, gt_charges, cutoff_mask = get_init_crystal_states(path, edge_encoding_dim = e_dim, SAMPLE_SIZE = None, r_switch = 0.5, r_cut = 5.0) # Change sample size to None if all samples should be read.

  0%|          | 0/500 [00:00<?, ?it/s]

In [2]:
def create_implicitly_batched_graphsTuple_with_encoded_distances(descriptors, distances, distances_encoded, init_charges, cutoff_mask, cutoff = R_CUT):
    batch_size = descriptors.shape[0]
    natom = descriptors.shape[1]
    # Reshaping the descriptors to go over the whole batch
    descriptors = jnp.reshape(descriptors,(batch_size*natom,descriptors.shape[2]))
    
    # to calculate the number of edges for each individual graph
    distances_flattened_batchwise = jnp.reshape(distances,(batch_size,natom*natom))
    n_edges = jnp.count_nonzero(jnp.logical_and(distances_flattened_batchwise > 0, distances_flattened_batchwise < cutoff),axis=1)
    n_nodes = jnp.repeat(jnp.array([natom]),batch_size)
    # Create a flattened index over all previously diagonal elements to be able to delete them from the flattened arrays.
    flatten_idx = jnp.nonzero(jnp.logical_and(distances.flatten() > 0, distances.flatten() < cutoff))[0]
    idx = jnp.nonzero(jnp.logical_and(distances.flatten() > 0, distances.flatten() < cutoff))[0]
    # Make sure that there are only edges between nodes of the same graph
    # Batch range to add onto the tiled outer products
    batch_range = jnp.reshape(jnp.repeat(jnp.arange(batch_size)*natom,natom*natom),(batch_size,natom,natom))
    # outer product over the atoms
    outer = jnp.tile(jnp.outer(jnp.ones(natom),jnp.arange(natom)).astype(jnp.int32),batch_size).reshape(batch_size,natom,natom)
    # transposed for the other variant
    outer_transposed = jnp.transpose(outer, axes=(0,2,1))
    senders = jnp.add(outer_transposed,batch_range).flatten()[flatten_idx]
    receivers = jnp.add(outer,batch_range).flatten()[flatten_idx]
    sender_descriptors = descriptors[senders,:]
    receiver_descriptors = descriptors[senders,:]
    # Encoded distances are also flattened. Combinations of the same node (diagonal) are deleted
    graph_edges = jnp.reshape(distances_encoded,(distances_encoded.shape[0]*distances_encoded.shape[1]*distances_encoded.shape[2],48))[flatten_idx,:]
    # Same for cutoff_mask
    cutoff_mask = cutoff_mask.flatten()[flatten_idx]
    # Nodes contain charges
    # Edges contain concatenation of descriptors, edge_embeddings and cutoff_mask (which will be removed in the Network)
    graph= jraph.GraphsTuple(nodes = init_charges.flatten(),
                            # nodes = jnp.concatenate([descriptors,init_charges],axis=-1), Alternative 
                            senders = senders,
                            receivers = receivers,
                            edges = jnp.concatenate([receiver_descriptors, sender_descriptors, graph_edges, jnp.expand_dims(cutoff_mask,axis=-1)],axis=-1),
                            n_node = n_nodes,
                            n_edge = n_edges,
                            globals = None)
    return graph

size = 100
graph = create_implicitly_batched_graphsTuple_with_encoded_distances(descriptors[:size],distances[:size], distances_encoded[:size],init_charges[:size],cutoff_mask[:size])
graph2 = create_implicitly_batched_graphsTuple_with_encoded_distances(descriptors[size:size+100],distances[size:size+100], distances_encoded[size:size+100],init_charges[size:size+100],cutoff_mask[size:size+100])
# graph = create_implicitly_batched_graphsTuple_with_encoded_distances(descriptors,distances, distances_encoded,init_charges,cutoff_mask)

In [3]:
def print_graph_stats(graph):
    print("Nodes-Shape:",graph[0].shape)
    print("Edges-Shape:",graph[1].shape)
    print("Senders-Length:",graph[2].shape)
    print("Globals:",None if graph[4] is None else graph[4].shape)
    print("n_nodes:",graph[5])
    print("n_edges:",graph[6])

print_graph_stats(graph)

Nodes-Shape: (10500,)
Edges-Shape: (95044, 301)
Senders-Length: (95044,)
Globals: None
n_nodes: [105 105 105 105 105 105 105 105 105 105 105 105 105 105 105 105 105 105
 105 105 105 105 105 105 105 105 105 105 105 105 105 105 105 105 105 105
 105 105 105 105 105 105 105 105 105 105 105 105 105 105 105 105 105 105
 105 105 105 105 105 105 105 105 105 105 105 105 105 105 105 105 105 105
 105 105 105 105 105 105 105 105 105 105 105 105 105 105 105 105 105 105
 105 105 105 105 105 105 105 105 105 105]
n_edges: [958 942 926 954 942 950 950 966 974 958 970 934 962 946 942 958 930 958
 954 950 954 958 946 938 986 958 946 942 930 954 930 950 966 974 958 958
 922 962 958 934 934 938 950 966 946 938 926 954 970 946 946 946 950 934
 942 942 946 986 954 942 958 938 982 962 946 962 918 966 982 962 942 966
 946 946 958 926 946 934 966 926 938 930 930 954 942 934 938 950 958 998
 922 950 930 934 974 974 946 970 974 962]


In [4]:
import jax.tree_util as tree
import haiku as hk

def aggregate_edges_for_nodes_fn(edges: jnp.array,
                                receivers: jnp.array,
                                cutoff_mask: jnp.array,
                                n_nodes: int) -> jnp.array:
  edges = jnp.multiply(edges,cutoff_mask)
  return jax.ops.segment_sum(edges,receivers,n_nodes)


class MLP_haiku(hk.Module):
  def __init__(self, features: jnp.ndarray):
    super().__init__()
    self.features = features

  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
    layers = []
    for feat in self.features[:-1]:
      layers.append(hk.Linear(feat))
      layers.append(jax.nn.relu)
    layers.append(hk.Linear(self.features[-1]))
    mlp = hk.Sequential(layers)
    return mlp(x)



# Adapted from https://github.com/deepmind/jraph/blob/master/jraph/_src/models.py#L506
def GraphElectronPassing(aggregate_edges_for_nodes_fn: Callable,
                        MLP: Callable,
                        h_dim: int = 126) -> Callable:
  """
  Args:
    update_node_fn: function used to update the nodes. In the paper a single
      layer MLP is used.
    aggregate_edges_for_nodes_fn: function used to aggregates the sender nodes.

  Returns:
    A method that applies a Graph Convolution layer.
  """

  def _ApplyGEP(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
    """Applies a Graph Convolution layer."""
    nodes, edges, receivers, senders, _, _, _ = graph
    receiver_descriptors = edges[:,:h_dim]
    sender_descriptors=edges[:,h_dim:h_dim*2]
    graph_edges = edges[:,h_dim*2:-1]
    cutoff_mask = jnp.expand_dims(edges[:,-1],axis=-1)
    sender_charges = jnp.expand_dims(nodes[senders],axis=-1)
    receiver_charges = jnp.expand_dims(nodes[receivers],axis=-1)
    # Neural network forward: NN(q_v, q_w, h_v, h_w, e_vw) from the paper
    edges = jnp.concatenate([receiver_charges, sender_charges, receiver_descriptors, sender_descriptors, edges],axis=-1)
    edges_reversed = jnp.concatenate([sender_charges, receiver_charges, sender_descriptors, receiver_descriptors, edges],axis=-1)
    # Subtraction of both outputs to create electron-passing-output for atom v
    MLP_outputs = jnp.subtract(MLP(edges),MLP(edges_reversed))
    # aggregate_edges_for_nodes_fn is the weighting function with the cutoff_mask
    received_attributes = tree.tree_map(
      lambda e: aggregate_edges_for_nodes_fn(e, receivers, cutoff_mask, nodes.shape[0]), MLP_outputs)
    nodes = received_attributes.flatten()
    return graph._replace(nodes=nodes)
  return _ApplyGEP


gep_layer = GraphElectronPassing(
    aggregate_edges_for_nodes_fn=aggregate_edges_for_nodes_fn,
    MLP = lambda n: MLP_haiku(features=[32,32,32, 1])(n),
)
NUM_PASSES = 2
model = hk.without_apply_rng(hk.transform(gep_layer))
params = model.init(jax.random.PRNGKey(42), graph)
true_labels = gt_charges[:size].flatten()
true_labels_val = gt_charges[size:size+100].flatten()
opt_init, opt_update = optax.adam(1e-2)
opt_state = opt_init(params)

out_graph = model.apply(params,graph)
# out_graph = model.apply(params,out_graph)
# out_graph = model.apply(params,out_graph)


@jax.jit
def rmse_loss(params: hk.Params, graph: jraph.GraphsTuple,  ground_truth: jnp.array, num: int = 2) -> jnp.ndarray:
    # hk.fori_loop(0,3, model.apply, graph, params=params)
    output = model.apply(params, graph)
    for i in range(num-1):
      output = model.apply(params, output)
    # output = model.apply(params, output)
    # output = model.apply(params, output)
    return jnp.sqrt(jnp.sum(jnp.square(output[0]-ground_truth)/len(ground_truth)))

@jax.jit
def update(params: hk.Params, opt_state) -> Tuple[hk.Params, Any]:
    """Returns updated params and state."""
    g = jax.grad(rmse_loss)(params, graph, true_labels)
    updates, opt_state = opt_update(g, opt_state)
    return optax.apply_updates(params, updates), opt_state

print(rmse_loss(params, graph, true_labels),rmse_loss(params, graph2, true_labels_val)) 
for i in range(100):
  params, opt_state = update(params, opt_state)
  if (i % 10)==0: 
    print(rmse_loss(params, graph,true_labels),rmse_loss(params, graph2, true_labels_val)) 

1.5713183 1.5674556
1.9611565 1.9656975
1.0207137 1.0227226
0.26732582 0.26901242
0.15774243 0.16158329
0.11951331 0.12304594
0.1018314 0.10459877
0.08950808 0.09216454
0.080596566 0.083806045
0.07386086 0.077766344
0.06826447 0.072916836


In [2]:
# def create_graphsTuple_with_encoded_distances(descriptors, distances, distances_encoded, init_charges, cutoff_mask, cutoff = R_CUT):
#     natom = descriptors.shape[0]
#     # Create a flattened index over all previously diagonal elements to be able to delete them from the flattened arrays. 
#     flatten_idx = jnp.nonzero(jnp.logical_and(distances > 0, distances < cutoff).flatten())[0]
#     senders = jnp.outer(jnp.ones(natom),jnp.arange(natom)).T.flatten()[flatten_idx].astype(jnp.int32)
#     receivers = jnp.outer(jnp.ones(natom),jnp.arange(natom)).flatten()[flatten_idx].astype(jnp.int32)
#     sender_descriptors = descriptors[senders,:]
#     # print(senders,receivers)
#     receiver_descriptors = descriptors[senders,:]
#     n_nodes = jnp.array([natom])
#     n_edges = jnp.array([natom*natom-natom])
#     # Encoded distances are also flattened. Combinations of the same node (diagonal) are deleted
#     graph_edges = jnp.reshape(distances_encoded,(distances_encoded.shape[0]*distances_encoded.shape[1],48))[flatten_idx,:]
#     # Same for cutoff_mask
#     cutoff_mask = cutoff_mask.flatten()[flatten_idx]
#     # Nodes contain charges
#     # Edges contain concatenation of descriptors, edge_embeddings and cutoff_mask (which will be removed in the Network)
#     graph= jraph.GraphsTuple(nodes = init_charges,
#                             # nodes = jnp.concatenate([descriptors,init_charges],axis=-1), Alternative 
#                             senders = senders,
#                             receivers = receivers,
#                             edges = jnp.concatenate([receiver_descriptors, sender_descriptors, graph_edges, jnp.expand_dims(cutoff_mask,axis=-1)],axis=-1),
#                             n_node = n_nodes,
#                             n_edge = n_edges,
#                             globals = None)
#     return graph

# graph = create_graphsTuple_with_encoded_distances(descriptors[0],distances[0], distances_encoded[0],init_charges[0],cutoff_mask[0])

In [6]:
# natom = 4
# batch_size = 2

# batch_range = jnp.reshape(jnp.repeat(jnp.arange(batch_size)*natom,natom*natom),(batch_size,natom,natom))
# outer = jnp.tile(jnp.outer(jnp.ones(natom),jnp.arange(natom)).astype(jnp.int32),batch_size).reshape(batch_size,natom,natom)
# outer_transposed = jnp.transpose(outer, axes=(0,2,1))
# senders = jnp.add(outer_transposed,batch_range).flatten()
# receivers = jnp.add(outer,batch_range).flatten()

# senders, receivers