From 604a99d8bb450f16f68b102f955ab6a3f3b452ae Mon Sep 17 00:00:00 2001 From: Shyue Ping Ong Date: Sun, 24 Jul 2022 10:13:51 -0700 Subject: [PATCH] Make new and old MPRester usage transparent. --- pymatgen/ext/matproj.py | 19 ++++++++++++++-- pymatgen/ext/tests/test_matproj.py | 36 +++++++++++++++--------------- 2 files changed, 35 insertions(+), 20 deletions(-) diff --git a/pymatgen/ext/matproj.py b/pymatgen/ext/matproj.py index 90db386dda9..291d1b7b2b3 100644 --- a/pymatgen/ext/matproj.py +++ b/pymatgen/ext/matproj.py @@ -79,7 +79,7 @@ class TaskType(Enum): LDA_STATIC_DIEL = "LDA Static Dielectric" -class MPRester: +class MPResterOld: """ A class to conveniently interface with the Materials Project REST interface. The recommended way to use MPRester is with the "with" context @@ -1720,7 +1720,7 @@ def parse_tok(t): return {"$or": list(map(parse_tok, toks))} -class MPRester2: +class MPResterNew: """ A new MPRester that supports the new MP API. If you are getting your API key from the new dashboard of MP, you will need to use this instead of the original MPRester because the new API keys do not work with the old MP API (???!). @@ -1872,6 +1872,21 @@ def get_initial_structures_by_material_id( return structures +API_KEY = SETTINGS.get("PMG_MAPI_KEY", "") +try: + session = requests.Session() + session.headers = {"x-api-key": API_KEY} + response = session.get("https://api.materialsproject.org/materials/mp-262?_fields=formula_pretty") + if response.status_code != 200: + print("API key not specfied or using old API key. Default to old MPRester") + MPRester = MPResterOld # type: ignore + else: + MPRester = MPResterNew # type: ignore +except Exception: + print("API key not specfied or using old API key. Default to old MPRester") + MPRester = MPResterOld # type: ignore + + class MPRestError(Exception): """ Exception class for MPRestAdaptor. diff --git a/pymatgen/ext/tests/test_matproj.py b/pymatgen/ext/tests/test_matproj.py index 4668511dca2..58a865524f5 100644 --- a/pymatgen/ext/tests/test_matproj.py +++ b/pymatgen/ext/tests/test_matproj.py @@ -22,7 +22,7 @@ from pymatgen.electronic_structure.dos import CompleteDos from pymatgen.entries.compatibility import MaterialsProject2020Compatibility from pymatgen.entries.computed_entries import ComputedEntry -from pymatgen.ext.matproj import MP_LOG_FILE, MPRester, MPRestError, TaskType +from pymatgen.ext.matproj import MP_LOG_FILE, MPResterOld, MPRestError, TaskType from pymatgen.io.cif import CifParser from pymatgen.phonon.bandstructure import PhononBandStructureSymmLine from pymatgen.phonon.dos import CompletePhononDos @@ -38,11 +38,11 @@ (not SETTINGS.get("PMG_MAPI_KEY")) or (not website_is_up), "PMG_MAPI_KEY environment variable not set or MP is down.", ) -class MPResterTest(PymatgenTest): +class MPResterOldTest(PymatgenTest): _multiprocess_shared_ = True def setUp(self): - self.rester = MPRester() + self.rester = MPResterOld() warnings.simplefilter("ignore") def tearDown(self): @@ -128,14 +128,14 @@ def test_get_materials_id_from_task_id(self): self.assertEqual(self.rester.get_materials_id_from_task_id("mp-540081"), "mp-19017") def test_get_materials_id_references(self): - # nosetests pymatgen/matproj/tests/test_matproj.py:MPResterTest.test_get_materials_id_references - m = MPRester() + # nosetests pymatgen/matproj/tests/test_matproj.py:MPResterOldTest.test_get_materials_id_references + m = MPResterOld() data = m.get_materials_id_references("mp-123") self.assertTrue(len(data) > 1000) def test_find_structure(self): - # nosetests pymatgen/matproj/tests/test_matproj.py:MPResterTest.test_find_structure - m = MPRester() + # nosetests pymatgen/matproj/tests/test_matproj.py:MPResterOldTest.test_find_structure + m = MPResterOld() ciffile = self.TEST_FILES_DIR / "Fe3O4.cif" data = m.find_structure(str(ciffile)) self.assertTrue(len(data) > 1) @@ -434,34 +434,34 @@ def test_download_info(self): ) def test_parse_criteria(self): - crit = MPRester.parse_criteria("mp-1234 Li-*") + crit = MPResterOld.parse_criteria("mp-1234 Li-*") self.assertIn("Li-O", crit["$or"][1]["chemsys"]["$in"]) self.assertIn({"task_id": "mp-1234"}, crit["$or"]) - crit = MPRester.parse_criteria("Li2*") + crit = MPResterOld.parse_criteria("Li2*") self.assertIn("Li2O", crit["pretty_formula"]["$in"]) self.assertIn("Li2I", crit["pretty_formula"]["$in"]) self.assertIn("CsLi2", crit["pretty_formula"]["$in"]) - crit = MPRester.parse_criteria("Li-*-*") + crit = MPResterOld.parse_criteria("Li-*-*") self.assertIn("Li-Re-Ru", crit["chemsys"]["$in"]) self.assertNotIn("Li-Li", crit["chemsys"]["$in"]) - comps = MPRester.parse_criteria("**O3")["pretty_formula"]["$in"] + comps = MPResterOld.parse_criteria("**O3")["pretty_formula"]["$in"] for c in comps: self.assertEqual(len(Composition(c)), 3, f"Failed in {c}") - chemsys = MPRester.parse_criteria("{Fe,Mn}-O")["chemsys"]["$in"] + chemsys = MPResterOld.parse_criteria("{Fe,Mn}-O")["chemsys"]["$in"] self.assertEqual(len(chemsys), 2) - comps = MPRester.parse_criteria("{Fe,Mn,Co}O")["pretty_formula"]["$in"] + comps = MPResterOld.parse_criteria("{Fe,Mn,Co}O")["pretty_formula"]["$in"] self.assertEqual(len(comps), 3, comps) # Let's test some invalid symbols - self.assertRaises(ValueError, MPRester.parse_criteria, "li-fe") - self.assertRaises(ValueError, MPRester.parse_criteria, "LO2") + self.assertRaises(ValueError, MPResterOld.parse_criteria, "li-fe") + self.assertRaises(ValueError, MPResterOld.parse_criteria, "LO2") - crit = MPRester.parse_criteria("POPO2") + crit = MPResterOld.parse_criteria("POPO2") self.assertIn("P2O3", crit["pretty_formula"]["$in"]) def test_include_user_agent(self): @@ -472,12 +472,12 @@ def test_include_user_agent(self): headers["user-agent"], ) self.assertIsNotNone(m, msg=f"Unexpected user-agent value {headers['user-agent']}") - self.rester = MPRester(include_user_agent=False) + self.rester = MPResterOld(include_user_agent=False) self.assertNotIn("user-agent", self.rester.session.headers, msg="user-agent header unwanted") def test_database_version(self): - with MPRester(notify_db_version=True) as mpr: + with MPResterOld(notify_db_version=True) as mpr: db_version = mpr.get_database_version() self.assertIsInstance(db_version, str)