Skip to content

Commit

Permalink
Improved handling of old and new MPRester.
Browse files Browse the repository at this point in the history
  • Loading branch information
Shyue Ping Ong committed Jul 24, 2022
1 parent 604a99d commit 7dfa138
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 30 deletions.
107 changes: 92 additions & 15 deletions pymatgen/ext/matproj.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class TaskType(Enum):
LDA_STATIC_DIEL = "LDA Static Dielectric"


class MPResterOld:
class _MPResterLegacy:
"""
A class to conveniently interface with the Materials Project REST
interface. The recommended way to use MPRester is with the "with" context
Expand Down Expand Up @@ -176,6 +176,14 @@ def __init__(
is similar to what most web browsers send with each page request.
Set to False to disable the user agent.
"""
warnings.warn(
"You are using the legacy MPRester, which is deprecated. If you are not a power user, ie., you"
"do not have a lot of legacy code that uses many different MPRester methods, it is recommended "
"you get a new API key from the new Materials Project front end. Once you use get your new API"
"key, using the new API key will automatically route you to using the new MPRester."
"to the legacy MPRester.",
DeprecationWarning,
)
if api_key is not None:
self.api_key = api_key
else:
Expand Down Expand Up @@ -364,6 +372,9 @@ def get_materials_ids(self, chemsys_formula):
"""
return self._make_request(f"/materials/{chemsys_formula}/mids", mp_decode=False)

# For backwards compatibility.
get_material_id = get_materials_ids

def get_doc(self, materials_id):
"""
Get the entire data document for one materials id. Use this judiciously.
Expand Down Expand Up @@ -1720,7 +1731,7 @@ def parse_tok(t):
return {"$or": list(map(parse_tok, toks))}


class MPResterNew:
class _MPResterNew:
"""
A new MPRester that supports the new MP API. If you are getting your API key from the new dashboard of MP, you will
need to use this instead of the original MPRester because the new API keys do not work with the old MP API (???!).
Expand Down Expand Up @@ -1829,6 +1840,20 @@ def get_summary_by_material_id(self, material_id: str, fields: list | None = Non
get = "_fields=" + ",".join(fields)
return self.request(f"summary/{material_id}?{get}")["data"][0]

get_doc = get_summary_by_material_id

def get_material_ids(self, formula):
"""
Get all materials ids for a formula.
Args:
formula (str): A formula (e.g., Fe2O3).
Returns:
([str]) List of all materials ids.
"""
return [d["material_id"] for d in self.get_summary({"formula": formula}, fields=["material_id"])]

def get_structure_by_material_id(self, material_id: str, conventional_unit_cell: bool = False) -> Structure:
"""
Get a Structure corresponding to a material_id.
Expand Down Expand Up @@ -1872,19 +1897,71 @@ def get_initial_structures_by_material_id(
return structures


API_KEY = SETTINGS.get("PMG_MAPI_KEY", "")
try:
session = requests.Session()
session.headers = {"x-api-key": API_KEY}
response = session.get("https://api.materialsproject.org/materials/mp-262?_fields=formula_pretty")
if response.status_code != 200:
print("API key not specfied or using old API key. Default to old MPRester")
MPRester = MPResterOld # type: ignore
else:
MPRester = MPResterNew # type: ignore
except Exception:
print("API key not specfied or using old API key. Default to old MPRester")
MPRester = MPResterOld # type: ignore
class MPRester:
"""
A class to conveniently interface with the Materials Project REST
interface. The recommended way to use MPRester is with the "with" context
manager to ensure that sessions are properly closed after usage::
with MPRester("API_KEY") as m:
do_something
MPRester uses the "requests" package, which provides for HTTP connection
pooling. All connections are made via https for security.
For more advanced uses of the Materials API, please consult the API
documentation at https://github.com/materialsproject/mapidoc.
Note that this barebones class is to handle transition between the old and new API keys in a transparent manner,
providing backwards compatibility. Use it as you would with normal MPRester usage. If a new API key is detected,
the _MPResterNew will be initialized. Otherwise, the _MPResterLegacy. At the current moment, full parity between
old and new API MPRester has not been implemented. This will be resolved in the near future. It is not recommended,
but if you would like to select the specific version of the MPRester, you can call initialize either _MPResterNew
or _MPResterLegacy directly.
"""

def __init__(self, *args, **kwargs):
r"""
Args:
*args: Pass through to either legacy or new MPRester.
**kwargs: Pass through to either legacy or new MPRester.
"""
if len(args) > 0:
api_key = args[0]
else:
api_key = kwargs.get("api_key", None)
if api_key is not None:
self.api_key = api_key
else:
self.api_key = SETTINGS.get("PMG_MAPI_KEY", "")

self.session = requests.Session()
self.session.headers = {"x-api-key": self.api_key}
if self.is_new_api():
self.mpr_mapped = _MPResterNew(*args, **kwargs)
else:
self.mpr_mapped = _MPResterLegacy(*args, **kwargs)

def __enter__(self):
return self.mpr_mapped

def __exit__(self, exc_type, exc_val, exc_tb):
self.mpr_mapped.__exit__(exc_type, exc_val, exc_tb)

def is_new_api(self):
session = requests.Session()
session.headers = {"x-api-key": self.api_key}
try:
response = session.get("https://api.materialsproject.org/materials/mp-262?_fields=formula_pretty")
if response.status_code == 200:
return True
else:
return False
except Exception:
return False

def __getattr__(self, name):
return getattr(self.mpr_mapped, name)


class MPRestError(Exception):
Expand Down
30 changes: 15 additions & 15 deletions pymatgen/ext/tests/test_matproj.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from pymatgen.electronic_structure.dos import CompleteDos
from pymatgen.entries.compatibility import MaterialsProject2020Compatibility
from pymatgen.entries.computed_entries import ComputedEntry
from pymatgen.ext.matproj import MP_LOG_FILE, MPResterOld, MPRestError, TaskType
from pymatgen.ext.matproj import MP_LOG_FILE, MPRestError, TaskType, _MPResterLegacy
from pymatgen.io.cif import CifParser
from pymatgen.phonon.bandstructure import PhononBandStructureSymmLine
from pymatgen.phonon.dos import CompletePhononDos
Expand All @@ -42,7 +42,7 @@ class MPResterOldTest(PymatgenTest):
_multiprocess_shared_ = True

def setUp(self):
self.rester = MPResterOld()
self.rester = _MPResterLegacy()
warnings.simplefilter("ignore")

def tearDown(self):
Expand Down Expand Up @@ -129,13 +129,13 @@ def test_get_materials_id_from_task_id(self):

def test_get_materials_id_references(self):
# nosetests pymatgen/matproj/tests/test_matproj.py:MPResterOldTest.test_get_materials_id_references
m = MPResterOld()
m = _MPResterLegacy()
data = m.get_materials_id_references("mp-123")
self.assertTrue(len(data) > 1000)

def test_find_structure(self):
# nosetests pymatgen/matproj/tests/test_matproj.py:MPResterOldTest.test_find_structure
m = MPResterOld()
m = _MPResterLegacy()
ciffile = self.TEST_FILES_DIR / "Fe3O4.cif"
data = m.find_structure(str(ciffile))
self.assertTrue(len(data) > 1)
Expand Down Expand Up @@ -434,34 +434,34 @@ def test_download_info(self):
)

def test_parse_criteria(self):
crit = MPResterOld.parse_criteria("mp-1234 Li-*")
crit = _MPResterLegacy.parse_criteria("mp-1234 Li-*")
self.assertIn("Li-O", crit["$or"][1]["chemsys"]["$in"])
self.assertIn({"task_id": "mp-1234"}, crit["$or"])

crit = MPResterOld.parse_criteria("Li2*")
crit = _MPResterLegacy.parse_criteria("Li2*")
self.assertIn("Li2O", crit["pretty_formula"]["$in"])
self.assertIn("Li2I", crit["pretty_formula"]["$in"])
self.assertIn("CsLi2", crit["pretty_formula"]["$in"])

crit = MPResterOld.parse_criteria("Li-*-*")
crit = _MPResterLegacy.parse_criteria("Li-*-*")
self.assertIn("Li-Re-Ru", crit["chemsys"]["$in"])
self.assertNotIn("Li-Li", crit["chemsys"]["$in"])

comps = MPResterOld.parse_criteria("**O3")["pretty_formula"]["$in"]
comps = _MPResterLegacy.parse_criteria("**O3")["pretty_formula"]["$in"]
for c in comps:
self.assertEqual(len(Composition(c)), 3, f"Failed in {c}")

chemsys = MPResterOld.parse_criteria("{Fe,Mn}-O")["chemsys"]["$in"]
chemsys = _MPResterLegacy.parse_criteria("{Fe,Mn}-O")["chemsys"]["$in"]
self.assertEqual(len(chemsys), 2)
comps = MPResterOld.parse_criteria("{Fe,Mn,Co}O")["pretty_formula"]["$in"]
comps = _MPResterLegacy.parse_criteria("{Fe,Mn,Co}O")["pretty_formula"]["$in"]
self.assertEqual(len(comps), 3, comps)

# Let's test some invalid symbols

self.assertRaises(ValueError, MPResterOld.parse_criteria, "li-fe")
self.assertRaises(ValueError, MPResterOld.parse_criteria, "LO2")
self.assertRaises(ValueError, _MPResterLegacy.parse_criteria, "li-fe")
self.assertRaises(ValueError, _MPResterLegacy.parse_criteria, "LO2")

crit = MPResterOld.parse_criteria("POPO2")
crit = _MPResterLegacy.parse_criteria("POPO2")
self.assertIn("P2O3", crit["pretty_formula"]["$in"])

def test_include_user_agent(self):
Expand All @@ -472,12 +472,12 @@ def test_include_user_agent(self):
headers["user-agent"],
)
self.assertIsNotNone(m, msg=f"Unexpected user-agent value {headers['user-agent']}")
self.rester = MPResterOld(include_user_agent=False)
self.rester = _MPResterLegacy(include_user_agent=False)
self.assertNotIn("user-agent", self.rester.session.headers, msg="user-agent header unwanted")

def test_database_version(self):

with MPResterOld(notify_db_version=True) as mpr:
with _MPResterLegacy(notify_db_version=True) as mpr:
db_version = mpr.get_database_version()

self.assertIsInstance(db_version, str)
Expand Down

0 comments on commit 7dfa138

Please sign in to comment.