# Example usage

Here, we provide a few use-case examples of the codebase to get new users started with contextual message passing models.

```python
import os
import json
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow_addons.activations import mish
```

## Crystal graph construction

To start, we'll demonstrate how to create crystal graphs suitable for input into contextual message passing models. For now, we'll assume we're working with the Materials Project 2018 `mp.2018.6.1.json` file provided at https://paperswithcode.com/dataset/materials-project.

```python
# Load the json into memory
with open(r"mp.2018.6.1.json", "r") as f:
    json_data = json.load(f)

# To make things easier, we'll perform this crystal graph construction
# for just the first entry.
# First, parse the structure CIF string and the target property.
structure_string = json_data[0]["structure"]                  # our "X"
target_property  = json_data[0]["formation_energy_per_atom"]  # our "y"

# We provide a preconstructed function in the data_input.graphs module for
# converting a structure string to a graph, so we'll use that here.
# To customise the construction, we'd recommend looking at the
# source code for mp_s2dgknn. By default, it creates a
# 48-KNN edge multiplex graph with both real- and reciprocal-space
# edges.
from ContextualMPNN.data_input.graphs import mp_s2dgknn
structure_graph = mp_s2dgknn(structure_string)
```
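
To give a feel for what a "KNN" edge set means in a periodic crystal, here is a minimal numpy sketch of k-nearest-neighbour edge construction that takes periodic images into account. This is not the implementation of `mp_s2dgknn`: the function `knn_periodic_edges` is hypothetical, it only builds real-space edges (reciprocal-space edges are not shown), and it assumes the k nearest neighbours lie within one unit cell of the home cell, which may not hold for small cells with a large k such as 48.

```python
import itertools
import numpy as np


def knn_periodic_edges(frac_coords, lattice, k):
    """Hypothetical sketch: k-nearest-neighbour edges in a periodic crystal.

    frac_coords: (n, 3) fractional coordinates of the atoms.
    lattice:     (3, 3) matrix whose rows are the lattice vectors.
    Returns (senders, receivers, distances), each of length n * k.
    Assumption: only the 27 cells around the home cell are searched.
    """
    frac_coords = np.asarray(frac_coords, dtype=float)
    lattice = np.asarray(lattice, dtype=float)
    n = len(frac_coords)

    # All 27 image offsets in {-1, 0, 1} along each lattice direction.
    offsets = np.array(list(itertools.product((-1, 0, 1), repeat=3)), dtype=float)
    # Cartesian positions of every periodic image: shape (27, n, 3).
    images = (frac_coords[None, :, :] + offsets[:, None, :]) @ lattice
    # Cartesian positions of the atoms in the home cell: shape (n, 3).
    home = frac_coords @ lattice

    senders, receivers, distances = [], [], []
    for i in range(n):
        # Distance from atom i to every image, flattened to (27 * n,).
        d = np.linalg.norm(images - home[i], axis=-1).ravel()
        # Flattened index o * n + j maps back to atom j.
        atom_index = np.arange(d.size) % n
        valid = d > 1e-8  # drop atom i itself (zero distance)
        order = np.argsort(d[valid], kind="stable")[:k]
        senders.extend([i] * len(order))
        receivers.extend(atom_index[valid][order])
        distances.extend(d[valid][order])
    return np.array(senders), np.array(receivers), np.array(distances)
```

For example, a single atom in a simple cubic cell with side 3.0 and `k=6` yields six self-edges (to its periodic images), each of length 3.0.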

By applying `mp_s2dgknn` to every entry, a dataset of graph-target pairs can be generated, assuming the original data is provided as CIF strings. Under the hood, the CIF strings are converted to pymatgen `Structure` objects, so if CIF files cannot be used, pymatgen structures are the next-best option.
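
A minimal sketch of that loop follows. The helper `build_graph_dataset` is hypothetical (not part of ContextualMPNN); it takes the graph-building function as an argument so that it can be tried without the package installed, and it records the indices of entries whose structure fails to convert instead of stopping the whole run.

```python
from typing import Any, Callable, Dict, List, Sequence, Tuple


def build_graph_dataset(
    entries: Sequence[Dict[str, Any]],
    to_graph: Callable[[Any], Any],
    structure_key: str = "structure",
    target_key: str = "formation_energy_per_atom",
) -> Tuple[List[Any], List[float], List[int]]:
    """Hypothetical helper: convert each entry into a (graph, target) pair.

    Returns the graphs, the matching float targets, and the indices of
    entries whose structure could not be converted (these are skipped).
    """
    graphs, targets, failed = [], [], []
    for i, entry in enumerate(entries):
        try:
            graph = to_graph(entry[structure_key])
        except Exception:  # broad on purpose: log and skip bad structures
            failed.append(i)
            continue
        graphs.append(graph)
        targets.append(float(entry[target_key]))
    return graphs, targets, failed
```

With the notebook's data this would be called as `graphs, targets, failed = build_graph_dataset(json_data, mp_s2dgknn)`. If you only have pymatgen `Structure` objects, one option is to serialise them to CIF strings first with `structure.to(fmt="cif")`.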

## Model creation

Model creation is similarly simple: we provide `make_context_model`, which builds the architecture used during our research, so the only inputs needed are the construction parameters. Here, we'll make a relatively small contextual model with a relatively coarse edge expansion; these settings can be changed and experimented with to suit individual needs.

```python
from ContextualMPNN.model.contextual import make_context_model

# Set the edge expansion parameters
gaussian_centers = np.linspace(0., 5., 100)
gaussian_width   = 0.25
layer_width      = 64

# Batch up the parameters and create the model
params = {"centers": gaussian_centers,
          "width":   gaussian_width,
          "C":       layer_width}
model = make_context_model(**params)

# Now, we'll simply compile it using standard Keras.
model.compile("adam", "mse")
model.summary()
```
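
The `centers` and `width` parameters describe a Gaussian expansion of each edge distance into a feature vector. As an illustration, here is a numpy sketch of the CGCNN-style basis $\exp\!\left(-(d - \mu_k)^2 / w^2\right)$; whether `make_context_model` uses exactly this form (for example $w^2$ versus $2w^2$ in the denominator) is an assumption, and `gaussian_expand` is a hypothetical name.

```python
import numpy as np


def gaussian_expand(distances, centers, width):
    """Expand each distance d into exp(-(d - mu_k)^2 / width^2) for every center mu_k.

    Assumed CGCNN-style form; returns an array of shape (*distances.shape, K).
    """
    d = np.asarray(distances, dtype=float)[..., None]  # (..., 1)
    mu = np.asarray(centers, dtype=float)              # (K,)
    return np.exp(-((d - mu) ** 2) / width ** 2)


# With the settings above, each distance becomes a 100-dimensional vector.
features = gaussian_expand([1.0, 2.5], np.linspace(0.0, 5.0, 100), 0.25)
print(features.shape)  # (2, 100)
```

With 100 centers over 0 to 5 the spacing is about 0.05, so a width of 0.25 makes neighbouring Gaussians overlap strongly and the resulting features vary smoothly with distance.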

## Model training

Here, we assume we're training a contextual MPNN with support for reciprocal-space features, and provide a code snippet showing how you might go about it.

```python
# First, build a graph for every entry and collect the matching targets.
# (Converting the full dataset can take a while.)
graphs  = [mp_s2dgknn(entry["structure"]) for entry in json_data]
targets = [entry["formation_energy_per_atom"] for entry in json_data]

# Next, flatten the graph dataset. As this example assumes dual graphs
# and a dataset of more than one graph, we import and use the
# appropriate function.
from ContextualMPNN.data_input.graphs import flatten_dual_dataset
X = flatten_dual_dataset(graphs)
y = np.array(targets).reshape(-1, 1)  # assume a single target per graph

# Now we import and create the batch generator Keras will use to draw samples
from ContextualMPNN.data_input.batch_generators import DualGraphBatchGenerator
data_generator = DualGraphBatchGenerator(*X, targets=y, batch_size=64)

# Now we simply call model.fit() as we would with any other Keras model!
# DualGraphBatchGenerator already pre-batches the graphs, so the batch size
# (and hence memory usage) is set on the generator above; Keras expects no
# batch_size argument in fit() when the input is a generator.
model.fit(data_generator,
          steps_per_epoch=len(data_generator),
          epochs=10)
```
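
Pre-batching means that each item the generator yields already contains many graphs. A common way to batch graphs of different sizes is a disjoint union: concatenate the node arrays, shift each graph's edge indices by the number of nodes before it, and keep a per-node graph index for pooling. The sketch below illustrates that technique; we have not confirmed that `DualGraphBatchGenerator` works exactly this way, and `batch_graphs` and its dict keys are hypothetical. It shows one graph per structure rather than a dual graph.

```python
import numpy as np


def batch_graphs(graphs):
    """Merge several graphs into one disconnected graph (disjoint union).

    Each graph is a dict with 'nodes' of shape (n_i, F) and 'edges' of shape
    (e_i, 2) holding integer (sender, receiver) node indices. Edge indices are
    shifted by the number of nodes in the preceding graphs, and 'graph_index'
    records which input graph each node came from.
    """
    nodes, edges, graph_index = [], [], []
    offset = 0
    for g_id, g in enumerate(graphs):
        n = len(g["nodes"])
        nodes.append(np.asarray(g["nodes"]))
        # reshape(-1, 2) also handles graphs with no edges.
        edges.append(np.asarray(g["edges"], dtype=int).reshape(-1, 2) + offset)
        graph_index.append(np.full(n, g_id, dtype=int))
        offset += n
    return {
        "nodes": np.concatenate(nodes, axis=0),
        "edges": np.concatenate(edges, axis=0),
        "graph_index": np.concatenate(graph_index),
    }
```

Inside a model, `graph_index` can be passed as the `segment_ids` of `tf.math.unsorted_segment_mean` to pool node features into one vector per crystal.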

Here, we've covered the most fundamental usage of the model and framework. In practice, individual projects will require much more tuning, along with more robust training and data-preparation schemes. Likewise, we haven't covered model construction in depth; for further details, we recommend reading the source code of `make_context_model(...)`.
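
As one small example of a more robust training scheme, here is a sketch of a reproducible random train/validation split over graph indices. `train_val_split` is a hypothetical helper, not part of ContextualMPNN.

```python
import numpy as np


def train_val_split(n_samples, val_fraction=0.1, seed=0):
    """Shuffle indices 0..n_samples-1 and split them into train/validation sets."""
    rng = np.random.default_rng(seed)
    indices = rng.permutation(n_samples)
    n_val = int(round(val_fraction * n_samples))
    return indices[n_val:], indices[:n_val]
```

One way to use it, assuming lists `graphs` and `targets` as in the training snippet, is to select each subset by index (for example `[graphs[i] for i in train_idx]`), flatten each subset separately with `flatten_dual_dataset`, build one `DualGraphBatchGenerator` per subset, and pass the validation generator to `model.fit` via `validation_data`, assuming the generator yields `(inputs, targets)` batches.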