In [1]:
import numpy as np
from typing import Dict, List
from tensorflow.keras.models import load_model
from megnet.layers import _CUSTOM_OBJECTS
from megnet.data.crystal import CrystalGraph
from megnet.models import GraphModel
from pymatgen.core import Structure
from megnet.utils.data import get_graphs_within_cutoff


class MyCrystalGraph(CrystalGraph):
    """CrystalGraph variant whose bonds come from ``get_graphs_within_cutoff``.

    Atom features come from the inherited ``get_atom_features``. Connectivity and
    bond values are taken directly from ``get_graphs_within_cutoff`` using
    ``self.nn_strategy.cutoff``; no bond-feature expansion is applied in this class.
    """

    def convert(self, structure: Structure, state_attributes: List = None) -> Dict:
        """Convert a structure into a MEGNet-style graph dictionary.

        Args:
            structure: Crystal structure to convert.
            state_attributes: Optional global state. When it is ``None`` or empty,
                ``structure.state`` is used if that attribute exists and is
                non-empty; otherwise a float32 ``[[0.0, 0.0]]`` default is used.

        Returns:
            Dict with keys ``"atom"``, ``"bond"``, ``"state"``, ``"index1"`` and
            ``"index2"``.
        """
        # Explicit None/size checks instead of `a or b or c`: `or` evaluates the
        # truth value of its operands, which raises ValueError for a numpy array
        # with more than one element. Empty inputs still fall through, as before.
        if state_attributes is None or np.size(state_attributes) == 0:
            state_attributes = getattr(structure, "state", None)
        if state_attributes is None or np.size(state_attributes) == 0:
            state_attributes = np.array([[0.0, 0.0]], dtype="float32")

        atoms = self.get_atom_features(structure)
        # Four arrays come back; the third is discarded. Presumably it holds the
        # periodic image offsets and `bonds` holds neighbor distances — TODO confirm.
        index1, index2, _, bonds = get_graphs_within_cutoff(structure, self.nn_strategy.cutoff)

        return {"atom": atoms, "bond": bonds, "state": state_attributes, "index1": index1, "index2": index2}
        


# Bond-feature settings: `nfeat_bond` Gaussian centers evenly spaced on
# [0, r_cutoff + 1] with a shared width. `r_cutoff` is also the neighbor-search
# cutoff for the graph converter (presumably in Angstrom — confirm).
nfeat_bond = 10
r_cutoff = 5
gaussian_centers = np.linspace(0, r_cutoff + 1, nfeat_bond)
gaussian_width = 0.5
graph_converter = MyCrystalGraph(cutoff=r_cutoff)

# Pretrained checkpoint. The file name presumably encodes epoch and validation
# MAE — confirm against the training run that produced it.
MODEL_PATH = "./val_mae_00486_0.031741.hdf5"

# Keep the raw Keras model under its own name rather than rebinding `model` to a
# different kind of object; `model` is the GraphModel wrapper used for prediction.
keras_model = load_model(MODEL_PATH, custom_objects=_CUSTOM_OBJECTS)
# NOTE(review): confirm that GraphModel actually consumes `centers`/`width`. If the
# Gaussian expansion is baked into the saved model, these may be ignored and
# editing them here would have no effect on predictions.
model = GraphModel(model=keras_model, graph_converter=graph_converter, centers=gaussian_centers, width=gaussian_width)



In [2]:
from matbench.bench import MatbenchBenchmark


# Build the benchmark container with autoload disabled so no dataset is read up
# front, then pull out and load only the formation-energy task scored below.
# Loading is the slow step (about two minutes in the logged run).
mb = MatbenchBenchmark(autoload=False)
task = mb.tasks_map["matbench_mp_e_form"]
task.load()


2023-02-27 21:51:54 INFO     Initialized benchmark 'matbench_v0.1' with 13 tasks: 
['matbench_dielectric',
 'matbench_expt_gap',
 'matbench_expt_is_metal',
 'matbench_glass',
 'matbench_jdft2d',
 'matbench_log_gvrh',
 'matbench_log_kvrh',
 'matbench_mp_e_form',
 'matbench_mp_gap',
 'matbench_mp_is_metal',
 'matbench_perovskites',
 'matbench_phonons',
 'matbench_steels']


INFO:matbench:Initialized benchmark 'matbench_v0.1' with 13 tasks: 
['matbench_dielectric',
 'matbench_expt_gap',
 'matbench_expt_is_metal',
 'matbench_glass',
 'matbench_jdft2d',
 'matbench_log_gvrh',
 'matbench_log_kvrh',
 'matbench_mp_e_form',
 'matbench_mp_gap',
 'matbench_mp_is_metal',
 'matbench_perovskites',
 'matbench_phonons',
 'matbench_steels']


2023-02-27 21:51:54 INFO     Loading dataset 'matbench_mp_e_form'...


INFO:matbench.task:Loading dataset 'matbench_mp_e_form'...


2023-02-27 21:53:55 INFO     Dataset 'matbench_mp_e_form loaded.


INFO:matbench.task:Dataset 'matbench_mp_e_form loaded.


In [3]:
# Score the pretrained model on each fold's held-out test split and record it.
# NOTE(review): no per-fold training happens here. The same checkpoint is
# evaluated on all folds, whereas the Matbench protocol fits a model on each
# fold's training data (task.get_train_and_val_data(fold)). Unless the
# checkpoint's training set is disjoint from every test fold, these scores may be
# optimistic due to test leakage — confirm how the checkpoint was trained.
for fold in task.folds:
    # Test structures only (targets withheld); presumably a pandas Series, since
    # `.values` is taken below.
    test_inputs = task.get_test_data(fold, include_target=False)

    # predict_structures may return a 2-D (n_structures, 1) array. Flatten it so
    # task.record receives one value per structure; this is a no-op for 1-D output.
    predictions = np.ravel(model.predict_structures(test_inputs.values))

    task.record(fold, predictions)

# Persist every recorded fold to a gzipped JSON results file.
mb.to_file("results.json.gz")

2023-02-27 21:55:31 INFO     Recorded fold matbench_mp_e_form-0 successfully.


INFO:matbench.task:Recorded fold matbench_mp_e_form-0 successfully.


2023-02-27 21:57:07 INFO     Recorded fold matbench_mp_e_form-1 successfully.


INFO:matbench.task:Recorded fold matbench_mp_e_form-1 successfully.


2023-02-27 21:58:44 INFO     Recorded fold matbench_mp_e_form-2 successfully.


INFO:matbench.task:Recorded fold matbench_mp_e_form-2 successfully.


2023-02-27 22:00:22 INFO     Recorded fold matbench_mp_e_form-3 successfully.


INFO:matbench.task:Recorded fold matbench_mp_e_form-3 successfully.


2023-02-27 22:02:01 INFO     Recorded fold matbench_mp_e_form-4 successfully.


INFO:matbench.task:Recorded fold matbench_mp_e_form-4 successfully.


2023-02-27 22:02:08 INFO     Successfully wrote MatbenchBenchmark to file 'results.json.gz'.


INFO:matbench.util:Successfully wrote MatbenchBenchmark to file 'results.json.gz'.
