# Introduction

This notebook demonstrates the use of M3GNet as a structure relaxer and provides a more comprehensive benchmark for cubic crystals, comparing the predicted lattice constants against experimental data from Wikipedia and DFT data from the Materials Project (MP). The benchmark is limited to cubic crystals for ease of comparison, since they have only one lattice parameter.

If you are running this notebook from Google Colab, uncomment the next code box to install matgl first.

In [None]:
# !pip install matgl

In [None]:
from __future__ import annotations

import traceback
import warnings

import numpy as np
import pandas as pd
from pymatgen.core import Composition, Lattice, Structure
from pymatgen.ext.matproj import MPRester
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from tqdm import tqdm

from matgl.apps.pes import Potential
from matgl.ext.ase import Relaxer
from matgl.models import M3GNet

for category in (UserWarning, DeprecationWarning):
    warnings.filterwarnings("ignore", category=category, module="tensorflow")
    warnings.filterwarnings("ignore", category=category, module="pymatgen")

The next cell just compiles data from Wikipedia. 

In [None]:
# Crystal-structure labels that are excluded from the benchmark; only the
# remaining rows are built as cubic cells further down.
NON_CUBIC_STRUCTURES = [
    "Hexagonal",
    "Wurtzite",
    "Wurtzite (HCP)",
    "Orthorhombic",
    "Tetragonal perovskite",
    "Orthorhombic perovskite",
]

# Take the first table on the Wikipedia page and reduce it to the three columns
# used below. Selecting the columns explicitly makes a separate drop of the
# "Ref." column unnecessary, so the cell no longer fails if that column is
# renamed or removed upstream.
data = (
    pd.read_html("http://en.wikipedia.org/wiki/Lattice_constant")[0]
    .rename(columns={"Lattice constant (Å)": "a (Å)"})
    .loc[lambda df: ~df["Crystal structure"].isin(NON_CUBIC_STRUCTURES)]
    # NOTE(review): "NC0.99" is dropped by name; presumably because the
    # non-stoichiometric formula cannot be handled downstream -- confirm.
    .loc[lambda df: df["Material"] != "NC0.99", ["Material", "Crystal structure", "a (Å)"]]
    .astype({"a (Å)": float})
)

# Extra elemental FCC crystals, consumed by add_new(). The text is a
# whitespace-separated sequence of triples "<integer> <element symbol> <a>",
# two triples per line. add_new() uses the symbol as the "Material" name and
# the third token as the lattice constant stored in the "a (Å)" column; the
# leading integer (the element's atomic number) is ignored.
additional_fcc = """10 Ne 4.43 54 Xe 6.20
13 Al 4.05 58 Ce 5.16
18 Ar 5.26 70 Yb 5.49
20 Ca 5.58 77 Ir 3.84
28 Ni 3.52 78 Pt 3.92
29 Cu 3.61 79 Au 4.08
36 Kr 5.72 82 Pb 4.95
38 Sr 6.08 47 Ag 4.09
45 Rh 3.80 89 Ac 5.31
46 Pd 3.89 90 Th 5.08"""

# Extra elemental BCC crystals, in the same triple format as additional_fcc.
additional_bcc = """3 Li 3.49 42 Mo 3.15
11 Na 4.23 55 Cs 6.05
19 K 5.23 56 Ba 5.02
23 V 3.02 63 Eu 4.61
24 Cr 2.88 73 Ta 3.31
26 Fe 2.87 74 W 3.16
37 Rb 5.59 41 Nb 3.30"""


def add_new(str_, structure_type, df):
    """Append crystals parsed from a whitespace-separated table to ``df``.

    ``str_`` is split on whitespace and read as consecutive triples of tokens.
    Within each triple the first token is ignored, the second is the material
    name and the third is its lattice constant. A trailing incomplete triple
    is ignored.

    Args:
        str_: Text holding the triples.
        structure_type: Value written to the second column of every added row.
        df: Existing table. It must have a "Material" column and exactly three
            columns, ordered as (material, structure type, lattice constant).

    Returns:
        A new DataFrame containing the rows of ``df`` followed by one row for
        every parsed material whose name is not already in ``df["Material"]``.
        ``df`` itself is not modified.

    Raises:
        ValueError: If a lattice-constant token cannot be converted to float.
    """
    tokens = str_.split()
    known = set(df["Material"])
    new_crystals = [
        [name, structure_type, float(lattice_constant)]
        for name, lattice_constant in zip(tokens[1::3], tokens[2::3])
        if name not in known
    ]
    if not new_crystals:
        # Nothing to add: skip concatenating an empty frame (deprecated in
        # recent pandas) but still hand back a fresh object.
        return df.copy()
    # Column names come from the frame that was passed in, not from a
    # module-level global, so the function works on any compatible frame.
    return pd.concat([df, pd.DataFrame(new_crystals, columns=df.columns)])


# Add the extra FCC elements first, then the BCC ones, and finally key the
# table by material name.
data = add_new(additional_bcc, "BCC", add_new(additional_fcc, "FCC", data)).set_index("Material")
print(data)

             Crystal structure     a (Å)
Material                                
C (diamond)      Diamond (FCC)  3.567000
Si               Diamond (FCC)  5.431021
Ge               Diamond (FCC)  5.658000
AlAs         Zinc blende (FCC)  5.660500
AlP          Zinc blende (FCC)  5.451000
...                        ...       ...
K                          BCC  5.230000
Ba                         BCC  5.020000
Eu                         BCC  4.610000
Cr                         BCC  2.880000
Rb                         BCC  5.590000

[92 rows x 2 columns]


In the next cell, we generate an initial structure for all the phases. The cubic lattice constant is set to an arbitrary value of 5 angstroms for all structures. The exact value does not matter much, but it cannot be too large, or the atoms will end up isolated because of the cutoffs M3GNet uses to determine bonds. We then relax each structure with the Relaxer, which uses the M3GNet universal interatomic potential (IAP) pre-trained on the Materials Project.

In [None]:
# Arbitrary starting lattice constant (in angstroms) for every structure.
INITIAL_LATTICE_CONSTANT = 5.0


def build_initial_structure(formula, crystal_structure, a=INITIAL_LATTICE_CONSTANT):
    """Build the starting cubic cell for one benchmark entry.

    Args:
        formula: Chemical formula understood by pymatgen's ``Composition``.
        crystal_structure: Structure label from the data table. The prototype
            is chosen by substring match, tested in this order: "Zinc blende",
            "Halite", "Caesium chloride", "Cubic perovskite", "Diamond",
            "BCC", "FCC".
        a: Cubic lattice constant of the starting cell.

    Returns:
        A pymatgen ``Structure``, or ``None`` if no prototype matches the label.

    Raises:
        IndexError: If the formula has fewer elements than the prototype needs.
    """
    # Elements are placed on the prototype sites in pymatgen's sorted order
    # (presumably electronegativity-based -- confirm against pymatgen).
    els = sorted(Composition(formula).elements)
    lattice = Lattice.cubic(a)

    if "Zinc blende" in crystal_structure:
        return Structure.from_spacegroup("F-43m", lattice, [els[0], els[1]], [[0, 0, 0], [0.25, 0.25, 0.75]])
    if "Halite" in crystal_structure:
        return Structure.from_spacegroup("Fm-3m", lattice, [els[0], els[1]], [[0, 0, 0], [0.5, 0, 0]])
    if "Caesium chloride" in crystal_structure:
        return Structure.from_spacegroup("Pm-3m", lattice, [els[0], els[1]], [[0, 0, 0], [0.5, 0.5, 0.5]])
    if "Cubic perovskite" in crystal_structure:
        return Structure(
            lattice,
            [els[0], els[1], els[2], els[2], els[2]],
            [[0.0, 0.0, 0.0], [0.5, 0.5, 0.5], [0.5, 0.5, 0], [0.0, 0.5, 0.5], [0.5, 0, 0.5]],
        )
    # "Diamond" must be tested before "FCC": the table label is "Diamond (FCC)".
    if "Diamond" in crystal_structure:
        return Structure.from_spacegroup("Fd-3m", lattice, [els[0]], [[0.25, 0.75, 0.25]])
    if "BCC" in crystal_structure:
        return Structure(lattice, [els[0]] * 2, [[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]])
    if "FCC" in crystal_structure:
        return Structure(
            lattice, [els[0]] * 4, [[0.0, 0.0, 0.0], [0.5, 0.5, 0], [0.0, 0.5, 0.5], [0.5, 0, 0.5]]
        )
    return None


def find_mp_lattice_constant(mpr, relaxed_structure):
    """Look up the MP lattice constant matching a relaxed structure.

    Iterates over the material ids returned for the structure's reduced formula
    and picks the first whose space-group number equals that of
    ``relaxed_structure``.

    Args:
        mpr: An open ``MPRester`` client.
        relaxed_structure: The relaxed pymatgen ``Structure``.

    Returns:
        Lattice parameter ``a`` of the matching entry's conventional standard
        cell, or ``NaN`` if no entry matches.
    """
    formula = relaxed_structure.composition.reduced_formula
    # The target space group does not depend on the candidate, so compute it once.
    target_spacegroup = SpacegroupAnalyzer(relaxed_structure).get_space_group_number()

    for material_id in mpr.get_material_ids(formula):
        try:
            sga = SpacegroupAnalyzer(mpr.get_structure_by_material_id(material_id))
            if sga.get_space_group_number() == target_spacegroup:
                return sga.get_conventional_standard_structure().lattice.a
        except Exception as exc:
            # One bad candidate should not hide the others; report it and move on.
            tqdm.write(f"{formula}: skipping {material_id} ({exc!r})")
    return np.nan


predicted = []
mp = []

# load the pre-trained M3GNet model. By default it is the model trained to MP-2021.2.8 database.
model, d = M3GNet.load("M3GNet-MP-2021.2.8-PES", include_json=True)
metadata = d["metadata"]
data_std = metadata["data_std"]
# NOTE(review): data_mean is read but not passed to Potential below -- confirm
# whether that is intentional.
data_mean = metadata["data_mean"]
element_refs = metadata["element_refs"]
# create the potential class
ff = Potential(model, data_std=data_std, element_refs=element_refs)
# create the M3GNet Relaxer
relaxer = Relaxer(potential=ff)

# Missing values are stored as NaN (not 0) so that they drop out of the error
# statistics instead of appearing as -100% or infinite errors.
with MPRester() as mpr:
    for material, row in tqdm(data.iterrows(), total=len(data)):
        # Keep only the formula part of names such as "C (diamond)".
        formula = material.split()[0]

        try:
            initial_structure = build_initial_structure(formula, row["Crystal structure"])
            if initial_structure is None:
                # Unsupported structure type: nothing to relax or compare.
                predicted.append(np.nan)
                mp.append(np.nan)
                continue
            final_structure = relaxer.relax(initial_structure, fmax=0.01)["final_structure"]
        except Exception as exc:
            # A single failed relaxation must not abort the whole benchmark.
            tqdm.write(f"{material}: relaxation failed ({exc!r})")
            predicted.append(np.nan)
            mp.append(np.nan)
            continue

        predicted.append(final_structure.lattice.a)

        try:
            mp.append(find_mp_lattice_constant(mpr, final_structure))
        except Exception as exc:
            tqdm.write(f"{material}: Materials Project lookup failed ({exc!r})")
            mp.append(np.nan)

data["MP a (Å)"] = mp
data["Predicted a (Å)"] = predicted

  return th.as_tensor(data, dtype=dtype)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  costheta = torch.tensor(costheta, dtype=torch.complex64)
  phi = torch.tensor(phi, dtype=torch.complex64)
  results = results.type(dtype=DataType.torch_float)
  assert input.numel() == input.storage().size(), (
Traceback (most recent call last):
  File "/var/folders/w6/yrmcztx969j0r2f2v6yy3gp00000gn/T/ipykernel_43497/2881363793.py", line 55, in <module>
    mids = mpr.get_material_ids(s.composition.reduced_formula)
  File "/Users/shyue/miniconda3/envs/mavrl/lib/python3.9/site-packages/mp_api/client/mprester.py", line 350, in get_material_ids
    for doc in self.materials.search(
  File "/Users/shyue/miniconda3/envs/mavrl/lib/python3.9/site-packages/mp_api/client/routes/materials.py", line 179, in search
    return super()._search(
  File "/Users/shyue/miniconda3/envs/mavrl/lib/python3.9/site-packages/mp_api/client/core/client.py", line 911, in _search
    return self._get_a

  3%|███▋                                                                                                             | 3/92 [00:05<02:46,  1.87s/it]Traceback (most recent call last):
  File "/var/folders/w6/yrmcztx969j0r2f2v6yy3gp00000gn/T/ipykernel_43497/2881363793.py", line 55, in <module>
    mids = mpr.get_material_ids(s.composition.reduced_formula)
  File "/Users/shyue/miniconda3/envs/mavrl/lib/python3.9/site-packages/mp_api/client/mprester.py", line 350, in get_material_ids
    for doc in self.materials.search(
  File "/Users/shyue/miniconda3/envs/mavrl/lib/python3.9/site-packages/mp_api/client/routes/materials.py", line 179, in search
    return super()._search(
  File "/Users/shyue/miniconda3/envs/mavrl/lib/python3.9/site-packages/mp_api/client/core/client.py", line 911, in _search
    return self._get_all_documents(
  File "/Users/shyue/miniconda3/envs/mavrl/lib/python3.9/site-packages/mp_api/client/core/client.py", line 960, in _get_all_documents
    results = self._query_re

  8%|████████▌                                                                                                        | 7/92 [00:10<01:45,  1.24s/it]Traceback (most recent call last):
  File "/var/folders/w6/yrmcztx969j0r2f2v6yy3gp00000gn/T/ipykernel_43497/2881363793.py", line 55, in <module>
    mids = mpr.get_material_ids(s.composition.reduced_formula)
  File "/Users/shyue/miniconda3/envs/mavrl/lib/python3.9/site-packages/mp_api/client/mprester.py", line 350, in get_material_ids
    for doc in self.materials.search(
  File "/Users/shyue/miniconda3/envs/mavrl/lib/python3.9/site-packages/mp_api/client/routes/materials.py", line 179, in search
    return super()._search(
  File "/Users/shyue/miniconda3/envs/mavrl/lib/python3.9/site-packages/mp_api/client/core/client.py", line 911, in _search
    return self._get_all_documents(
  File "/Users/shyue/miniconda3/envs/mavrl/lib/python3.9/site-packages/mp_api/client/core/client.py", line 960, in _get_all_documents
    results = self._query_re

 12%|█████████████▍                                                                                                  | 11/92 [00:14<01:11,  1.13it/s]Traceback (most recent call last):
  File "/var/folders/w6/yrmcztx969j0r2f2v6yy3gp00000gn/T/ipykernel_43497/2881363793.py", line 55, in <module>
    mids = mpr.get_material_ids(s.composition.reduced_formula)
  File "/Users/shyue/miniconda3/envs/mavrl/lib/python3.9/site-packages/mp_api/client/mprester.py", line 350, in get_material_ids
    for doc in self.materials.search(
  File "/Users/shyue/miniconda3/envs/mavrl/lib/python3.9/site-packages/mp_api/client/routes/materials.py", line 179, in search
    return super()._search(
  File "/Users/shyue/miniconda3/envs/mavrl/lib/python3.9/site-packages/mp_api/client/core/client.py", line 911, in _search
    return self._get_all_documents(
  File "/Users/shyue/miniconda3/envs/mavrl/lib/python3.9/site-packages/mp_api/client/core/client.py", line 960, in _get_all_documents
    results = self._query_re

 16%|██████████████████▎                                                                                             | 15/92 [00:17<01:10,  1.09it/s]Traceback (most recent call last):
  File "/var/folders/w6/yrmcztx969j0r2f2v6yy3gp00000gn/T/ipykernel_43497/2881363793.py", line 55, in <module>
    mids = mpr.get_material_ids(s.composition.reduced_formula)
  File "/Users/shyue/miniconda3/envs/mavrl/lib/python3.9/site-packages/mp_api/client/mprester.py", line 350, in get_material_ids
    for doc in self.materials.search(
  File "/Users/shyue/miniconda3/envs/mavrl/lib/python3.9/site-packages/mp_api/client/routes/materials.py", line 179, in search
    return super()._search(
  File "/Users/shyue/miniconda3/envs/mavrl/lib/python3.9/site-packages/mp_api/client/core/client.py", line 911, in _search
    return self._get_all_documents(
  File "/Users/shyue/miniconda3/envs/mavrl/lib/python3.9/site-packages/mp_api/client/core/client.py", line 960, in _get_all_documents
    results = self._query_re

 21%|███████████████████████▏                                                                                        | 19/92 [00:23<01:35,  1.31s/it]Traceback (most recent call last):
  File "/var/folders/w6/yrmcztx969j0r2f2v6yy3gp00000gn/T/ipykernel_43497/2881363793.py", line 55, in <module>
    mids = mpr.get_material_ids(s.composition.reduced_formula)
  File "/Users/shyue/miniconda3/envs/mavrl/lib/python3.9/site-packages/mp_api/client/mprester.py", line 350, in get_material_ids
    for doc in self.materials.search(
  File "/Users/shyue/miniconda3/envs/mavrl/lib/python3.9/site-packages/mp_api/client/routes/materials.py", line 179, in search
    return super()._search(
  File "/Users/shyue/miniconda3/envs/mavrl/lib/python3.9/site-packages/mp_api/client/core/client.py", line 911, in _search
    return self._get_all_documents(
  File "/Users/shyue/miniconda3/envs/mavrl/lib/python3.9/site-packages/mp_api/client/core/client.py", line 960, in _get_all_documents
    results = self._query_re

 25%|████████████████████████████                                                                                    | 23/92 [00:33<02:33,  2.22s/it]Traceback (most recent call last):
  File "/var/folders/w6/yrmcztx969j0r2f2v6yy3gp00000gn/T/ipykernel_43497/2881363793.py", line 55, in <module>
    mids = mpr.get_material_ids(s.composition.reduced_formula)
  File "/Users/shyue/miniconda3/envs/mavrl/lib/python3.9/site-packages/mp_api/client/mprester.py", line 350, in get_material_ids
    for doc in self.materials.search(
  File "/Users/shyue/miniconda3/envs/mavrl/lib/python3.9/site-packages/mp_api/client/routes/materials.py", line 179, in search
    return super()._search(
  File "/Users/shyue/miniconda3/envs/mavrl/lib/python3.9/site-packages/mp_api/client/core/client.py", line 911, in _search
    return self._get_all_documents(
  File "/Users/shyue/miniconda3/envs/mavrl/lib/python3.9/site-packages/mp_api/client/core/client.py", line 960, in _get_all_documents
    results = self._query_re

 29%|████████████████████████████████▊                                                                               | 27/92 [00:37<01:22,  1.27s/it]Traceback (most recent call last):
  File "/var/folders/w6/yrmcztx969j0r2f2v6yy3gp00000gn/T/ipykernel_43497/2881363793.py", line 55, in <module>
    mids = mpr.get_material_ids(s.composition.reduced_formula)
  File "/Users/shyue/miniconda3/envs/mavrl/lib/python3.9/site-packages/mp_api/client/mprester.py", line 350, in get_material_ids
    for doc in self.materials.search(
  File "/Users/shyue/miniconda3/envs/mavrl/lib/python3.9/site-packages/mp_api/client/routes/materials.py", line 179, in search
    return super()._search(
  File "/Users/shyue/miniconda3/envs/mavrl/lib/python3.9/site-packages/mp_api/client/core/client.py", line 911, in _search
    return self._get_all_documents(
  File "/Users/shyue/miniconda3/envs/mavrl/lib/python3.9/site-packages/mp_api/client/core/client.py", line 960, in _get_all_documents
    results = self._query_re

 34%|█████████████████████████████████████▋                                                                          | 31/92 [00:41<01:09,  1.14s/it]Traceback (most recent call last):
  File "/var/folders/w6/yrmcztx969j0r2f2v6yy3gp00000gn/T/ipykernel_43497/2881363793.py", line 55, in <module>
    mids = mpr.get_material_ids(s.composition.reduced_formula)
  File "/Users/shyue/miniconda3/envs/mavrl/lib/python3.9/site-packages/mp_api/client/mprester.py", line 350, in get_material_ids
    for doc in self.materials.search(
  File "/Users/shyue/miniconda3/envs/mavrl/lib/python3.9/site-packages/mp_api/client/routes/materials.py", line 179, in search
    return super()._search(
  File "/Users/shyue/miniconda3/envs/mavrl/lib/python3.9/site-packages/mp_api/client/core/client.py", line 911, in _search
    return self._get_all_documents(
  File "/Users/shyue/miniconda3/envs/mavrl/lib/python3.9/site-packages/mp_api/client/core/client.py", line 960, in _get_all_documents
    results = self._query_re

 38%|██████████████████████████████████████████▌                                                                     | 35/92 [00:44<00:42,  1.35it/s]Traceback (most recent call last):
  File "/var/folders/w6/yrmcztx969j0r2f2v6yy3gp00000gn/T/ipykernel_43497/2881363793.py", line 55, in <module>
    mids = mpr.get_material_ids(s.composition.reduced_formula)
  File "/Users/shyue/miniconda3/envs/mavrl/lib/python3.9/site-packages/mp_api/client/mprester.py", line 350, in get_material_ids
    for doc in self.materials.search(
  File "/Users/shyue/miniconda3/envs/mavrl/lib/python3.9/site-packages/mp_api/client/routes/materials.py", line 179, in search
    return super()._search(
  File "/Users/shyue/miniconda3/envs/mavrl/lib/python3.9/site-packages/mp_api/client/core/client.py", line 911, in _search
    return self._get_all_documents(
  File "/Users/shyue/miniconda3/envs/mavrl/lib/python3.9/site-packages/mp_api/client/core/client.py", line 960, in _get_all_documents
    results = self._query_re

 42%|███████████████████████████████████████████████▍                                                                | 39/92 [00:47<00:42,  1.25it/s]Traceback (most recent call last):
  File "/var/folders/w6/yrmcztx969j0r2f2v6yy3gp00000gn/T/ipykernel_43497/2881363793.py", line 55, in <module>
    mids = mpr.get_material_ids(s.composition.reduced_formula)
  File "/Users/shyue/miniconda3/envs/mavrl/lib/python3.9/site-packages/mp_api/client/mprester.py", line 350, in get_material_ids
    for doc in self.materials.search(
  File "/Users/shyue/miniconda3/envs/mavrl/lib/python3.9/site-packages/mp_api/client/routes/materials.py", line 179, in search
    return super()._search(
  File "/Users/shyue/miniconda3/envs/mavrl/lib/python3.9/site-packages/mp_api/client/core/client.py", line 911, in _search
    return self._get_all_documents(
  File "/Users/shyue/miniconda3/envs/mavrl/lib/python3.9/site-packages/mp_api/client/core/client.py", line 960, in _get_all_documents
    results = self._query_re

DGLError: Expect number of features to match number of nodes (len(u)). Got 2 and 0 instead.

In [None]:
# Relative (fractional) errors of the predicted lattice constant; they are
# rendered as percentages by the styling cell below.
# A value of 0 in the predicted / MP columns is a "no result" placeholder, so
# it is masked to NaN here. Otherwise it would show up as a -100% or infinite error.
predicted_a = data["Predicted a (Å)"].replace(0, np.nan)
mp_a = data["MP a (Å)"].replace(0, np.nan)
data["% error vs Expt"] = (predicted_a - data["a (Å)"]) / data["a (Å)"]
data["% error vs MP"] = (predicted_a - mp_a) / mp_a

In [None]:
# Show the table sorted by material, with both error columns rendered as
# percentages and a colour gradient applied.
percent_formats = dict.fromkeys(["% error vs Expt", "% error vs MP"], "{:,.2%}")
data.sort_index().style.format(percent_formats).background_gradient()

In [None]:
# Distribution of the relative error against MP. Entries without a usable MP
# reference (infinite or missing error) are left out.
ax = data["% error vs MP"].replace([np.inf, -np.inf], np.nan).dropna().hist(bins=20)
ax.set(
    xlabel="Relative error vs MP (fraction)",
    ylabel="Number of crystals",
    title="Predicted vs MP cubic lattice constants",
);