Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix examples/Calculating MLIP properties.ipynb using ML-relaxed structure for next model in loop #23

Merged
merged 3 commits into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/linting.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,22 @@ jobs:
- uses: actions/checkout@v4
with:
fetch-depth: 0

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: 3.11

- name: Install dependencies
run: |
pip install ruff mypy

- name: ruff
run: |
ruff --version
ruff check matcalc
ruff format matcalc --check

- name: mypy
run: |
mypy --version
Expand Down
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,22 @@ ci:

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.3.3
rev: v0.4.4
hooks:
- id: ruff
args: [--fix]
- id: ruff-format

- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
rev: v4.6.0
hooks:
- id: check-yaml
exclude: pymatgen/analysis/vesta_cutoffs.yaml
- id: end-of-file-fixer
- id: trailing-whitespace

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.9.0
rev: v1.10.0
hooks:
- id: mypy

Expand All @@ -35,7 +35,7 @@ repos:
additional_dependencies: [tomli] # needed to read pyproject.toml below py3.11

- repo: https://github.com/MarcoGorelli/cython-lint
rev: v0.16.0
rev: v0.16.2
hooks:
- id: cython-lint
args: [--no-pycodestyle]
Expand All @@ -48,6 +48,6 @@ repos:
args: [--drop-empty-cells, --keep-output]

- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.354
rev: v1.1.362
hooks:
- id: pyright
151 changes: 82 additions & 69 deletions examples/Calculating MLIP properties.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -7,84 +7,73 @@
"metadata": {},
"outputs": [],
"source": [
"import warnings\n",
"from matcalc.relaxation import RelaxCalc\n",
"from matcalc.phonon import PhononCalc\n",
"from matcalc.eos import EOSCalc\n",
"from matcalc.elasticity import ElasticityCalc\n",
"from matcalc.util import get_universal_calculator\n",
"from datetime import datetime\n",
"from matcalc.utils import get_universal_calculator\n",
"from tqdm import tqdm\n",
"from time import perf_counter\n",
"\n",
"from pymatgen.ext.matproj import MPRester"
"from mp_api.client import MPRester\n",
"\n",
"warnings.filterwarnings(\"ignore\", category=UserWarning, module=\"matgl\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/shyue/miniconda3/envs/mavrl/lib/python3.9/site-packages/mp_api/client/mprester.py:182: UserWarning: mpcontribs-client not installed. Install the package to query MPContribs data, or construct pourbaix diagrams: 'pip install mpcontribs-client'\n",
" warnings.warn(\n"
]
}
],
"source": [
"mpr = MPRester()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6e49374bdc1b4ce49b7db471709ea6b6",
"model_id": "6e2826e92227445fa9802d451544cd56",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Retrieving MaterialsDoc documents: 0%| | 0/20627 [00:00<?, ?it/s]"
"Retrieving MaterialsDoc documents: 0%| | 0/100 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"mp_data = mpr.materials._search(nelements=2, fields=[\"material_id\", \"structure\"])"
"mp_data = MPRester().materials.search(\n",
" num_sites=(1, 8), fields=[\"material_id\", \"structure\"], num_chunks=1, chunk_size=100\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3",
"id": "2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CHGNet initialized with 400,438 parameters\n",
"CHGNet will run on cpu\n"
"CHGNet v0.3.0 initialized with 412,525 parameters\n",
"CHGNet will run on mps\n",
"Using Materials Project MACE for MACECalculator with /Users/janosh/.cache/mace/5yyxdm76\n",
"Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization.\n",
"Default dtype float32 does not match model dtype float64, converting models to float32.\n"
]
}
],
"source": [
"universal_calcs = [(name, get_universal_calculator(name)) for name in (\"M3GNet\", \"CHGNet\")]"
"models = [(name, get_universal_calculator(name)) for name in (\"M3GNet\", \"CHGNet\", \"MACE\")]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4",
"id": "3",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -95,71 +84,95 @@
{
"cell_type": "code",
"execution_count": null,
"id": "5",
"id": "4",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 60%|████████████████████████████████████████████████████▊ | 12/20 [02:31<01:23, 10.47s/it]/Users/shyue/miniconda3/envs/mavrl/lib/python3.9/site-packages/dgl/core.py:82: DGLWarning: The input graph for the user-defined edge function does not contain valid edges\n",
" dgl_warning(\n",
" 70%|█████████████████████████████████████████████████████████████▌ | 14/20 [02:35<00:36, 6.13s/it]/Users/shyue/miniconda3/envs/mavrl/lib/python3.9/site-packages/phonopy/structure/cells.py:1396: UserWarning: Crystal structure is distorted in a tricky way so that phonopy could not handle the crystal symmetry properly. It is recommended to symmetrize crystal structure well and then re-start phonon calculation from scratch.\n",
" warnings.warn(msg)\n",
"100%|████████████████████████████████████████████████████████████████████████████████████████| 20/20 [03:34<00:00, 10.75s/it]\n"
"Processing mp-1185285 (Li1 Ac1 Hg2): 0%| | 0/10 [00:00<?, ?it/s]/Users/janosh/dev/matcalc/matcalc/relaxation.py:131: FutureWarning: Import ExpCellFilter from ase.filters\n",
" atoms = ExpCellFilter(atoms)\n",
"/Users/janosh/dev/matcalc/matcalc/relaxation.py:131: FutureWarning: Import ExpCellFilter from ase.filters\n",
" atoms = ExpCellFilter(atoms)\n",
"/Users/janosh/dev/matcalc/matcalc/relaxation.py:131: FutureWarning: Import ExpCellFilter from ase.filters\n",
" atoms = ExpCellFilter(atoms)\n",
"Processing mp-1183106 (Ac2 Zn1 In1): 10%|█ | 1/10 [00:08<01:17, 8.66s/it]/Users/janosh/dev/matcalc/matcalc/relaxation.py:131: FutureWarning: Import ExpCellFilter from ase.filters\n",
" atoms = ExpCellFilter(atoms)\n",
"/Users/janosh/dev/matcalc/matcalc/relaxation.py:131: FutureWarning: Import ExpCellFilter from ase.filters\n",
" atoms = ExpCellFilter(atoms)\n",
"/Users/janosh/dev/matcalc/matcalc/relaxation.py:131: FutureWarning: Import ExpCellFilter from ase.filters\n",
" atoms = ExpCellFilter(atoms)\n",
"Processing mp-976333 (Li2 Ac1 Tl1): 20%|██ | 2/10 [00:15<01:01, 7.68s/it] /Users/janosh/dev/matcalc/matcalc/relaxation.py:131: FutureWarning: Import ExpCellFilter from ase.filters\n",
" atoms = ExpCellFilter(atoms)\n",
"/Users/janosh/dev/matcalc/matcalc/relaxation.py:131: FutureWarning: Import ExpCellFilter from ase.filters\n",
" atoms = ExpCellFilter(atoms)\n",
"/Users/janosh/dev/matcalc/matcalc/relaxation.py:131: FutureWarning: Import ExpCellFilter from ase.filters\n",
" atoms = ExpCellFilter(atoms)\n",
"Processing mp-1006278 (Ac1 Eu1 Au2): 30%|███ | 3/10 [00:23<00:53, 7.68s/it]/Users/janosh/dev/matcalc/matcalc/relaxation.py:131: FutureWarning: Import ExpCellFilter from ase.filters\n",
" atoms = ExpCellFilter(atoms)\n",
"/Users/janosh/dev/matcalc/matcalc/relaxation.py:131: FutureWarning: Import ExpCellFilter from ase.filters\n",
" atoms = ExpCellFilter(atoms)\n",
"/Users/janosh/dev/matcalc/matcalc/relaxation.py:131: FutureWarning: Import ExpCellFilter from ase.filters\n",
" atoms = ExpCellFilter(atoms)\n",
"Processing mp-1183126 (Ac6 Pa2): 40%|████ | 4/10 [00:33<00:50, 8.50s/it] /Users/janosh/dev/matcalc/matcalc/relaxation.py:131: FutureWarning: Import ExpCellFilter from ase.filters\n",
" atoms = ExpCellFilter(atoms)\n",
"/Users/janosh/.venv/py311/lib/python3.11/site-packages/dgl/core.py:82: DGLWarning: The input graph for the user-defined edge function does not contain valid edges\n",
" dgl_warning(\n"
]
}
],
"source": [
"data = []\n",
"prop_preds = []\n",
"\n",
"for dct in (pbar := tqdm(mp_data[:10])): # Here we just do a sampling of 10 structures.\n",
" mat_id, formula = dct.material_id, dct.structure.formula\n",
" pbar.set_description(f\"Running {mat_id} ({formula})\")\n",
" model_preds = {\"material_id\": mat_id, \"formula\": formula, \"nsites\": len(dct.structure)}\n",
"\n",
"for d in tqdm(mp_data[:20]): # Here we just do a sampling of 20 structures.\n",
" s = d.structure\n",
" dd = {\"mid\": d.material_id, \"composition\": s.composition.formula, \"nsites\": len(s)}\n",
" for uc_name, uc in universal_calcs:\n",
" for model_name, model in models:\n",
" # The general principle is to do a relaxation first and just reuse the same structure.\n",
" prop_calcs = [\n",
" (\"relax\", RelaxCalc(uc, fmax=fmax, optimizer=opt)),\n",
" (\"elastic\", ElasticityCalc(uc, fmax=fmax, relax_structure=False)),\n",
" (\"eos\", EOSCalc(uc, fmax=fmax, relax_structure=False, optimizer=opt)),\n",
" (\"phonon\", PhononCalc(uc, fmax=fmax, relax_structure=False)),\n",
" (\"relax\", RelaxCalc(model, fmax=fmax, optimizer=opt)),\n",
" (\"elastic\", ElasticityCalc(model, fmax=fmax, relax_structure=False)),\n",
" (\"eos\", EOSCalc(model, fmax=fmax, relax_structure=False, optimizer=opt)),\n",
" (\"phonon\", PhononCalc(model, fmax=fmax, relax_structure=False)),\n",
" ]\n",
" properties = {}\n",
" for name, c in prop_calcs:\n",
" starttime = datetime.now()\n",
" properties[name] = c.calc(s)\n",
" endtime = datetime.now()\n",
" for name, prop_calc in prop_calcs:\n",
" start_time = perf_counter()\n",
" properties[name] = prop_calc.calc(dct.structure)\n",
" if name == \"relax\":\n",
" # Replace the structure with the one from relaxation for other property computations.\n",
" s = properties[name][\"final_structure\"]\n",
" dd[f\"time_{name}_{uc_name}\"] = (endtime - starttime).total_seconds()\n",
" dd[uc_name] = properties\n",
" data.append(dd)"
" struct = properties[name][\"final_structure\"]\n",
" model_preds[f\"time_{name}_{model_name}\"] = perf_counter() - start_time\n",
" model_preds[model_name] = properties\n",
" prop_preds.append(model_preds)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6",
"id": "5",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"df = pd.DataFrame(data)\n",
"for uc_name, _ in universal_calcs:\n",
" df[f\"time_total_{uc_name}\"] = (\n",
" df[f\"time_relax_{uc_name}\"]\n",
" + df[f\"time_elastic_{uc_name}\"]\n",
" + df[f\"time_phonon_{uc_name}\"]\n",
" + df[f\"time_eos_{uc_name}\"]\n",
"df_preds = pd.DataFrame(prop_preds)\n",
"for model_name, _ in models:\n",
" df_preds[f\"time_total_{model_name}\"] = (\n",
" df_preds[f\"time_relax_{model_name}\"]\n",
" + df_preds[f\"time_elastic_{model_name}\"]\n",
" + df_preds[f\"time_phonon_{model_name}\"]\n",
" + df_preds[f\"time_eos_{model_name}\"]\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7",
"id": "6",
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -705,13 +718,13 @@
}
],
"source": [
"df"
"df_preds"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8",
"id": "7",
"metadata": {},
"outputs": [
{
Expand All @@ -726,13 +739,13 @@
}
],
"source": [
"ax = df.plot(x=\"nsites\", y=\"time_relax_M3GNet\", kind=\"scatter\")"
"ax = df_preds.plot(x=\"nsites\", y=\"time_relax_M3GNet\", kind=\"scatter\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9",
"id": "8",
"metadata": {},
"outputs": [
{
Expand All @@ -747,8 +760,8 @@
}
],
"source": [
"ax = df[\"time_total_M3GNet\"].hist()\n",
"ax = df[\"time_total_CHGNet\"].hist()"
"ax = df_preds[\"time_total_M3GNet\"].hist()\n",
"ax = df_preds[\"time_total_CHGNet\"].hist()"
]
}
],
Expand All @@ -768,7 +781,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.17"
"version": "3.11.7"
}
},
"nbformat": 4,
Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ packages = ["matcalc"]
[tool.ruff]
target-version = "py39"
line-length = 120
extend-include = ["*.ipynb"]

[tool.ruff.lint]
select = ["ALL"]
Expand All @@ -76,6 +77,9 @@ exclude = ["docs/conf.py"]
pydocstyle.convention = "google"
isort.required-imports = ["from __future__ import annotations"]

[tool.ruff.format]
docstring-code-format = true

[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["F401"]
"tasks.py" = ["ANN", "D", "T203"]
Expand Down
4 changes: 2 additions & 2 deletions requirements-ci.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@ coveralls
mypy
ruff
black
dgl==1.1.3
matgl==0.9.1
dgl
matgl
chgnet
Loading