Skip to content

Commit

Permalink
add per-element MPtrj magmom ptable histogram
Browse files Browse the repository at this point in the history
add missing keys mp_trj_extxyz, mace_checkpoint1, mace_checkpoint2 to DataFiles
  • Loading branch information
janosh committed Nov 29, 2023
1 parent 6781273 commit 13cbb90
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 11 deletions.
54 changes: 45 additions & 9 deletions data/mp/eda_mp_trj.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@

# %%
import io
import os
from zipfile import ZipFile

import ase
import ase.io.extxyz
import numpy as np
import pandas as pd
import plotly.express as px
from pymatviz import count_elements, ptable_heatmap, ptable_heatmap_ratio
from pymatviz import count_elements, ptable_heatmap, ptable_heatmap_ratio, ptable_hists
from pymatviz.io import save_fig
from pymatviz.utils import si_fmt
from tqdm import tqdm
Expand Down Expand Up @@ -60,15 +61,15 @@


# %%
info_to_id = lambda info: f"{info['task_id']}-{info['calc_id']}-{info['ionic_step']}"

df_mp_trj = pd.DataFrame(
{
f"{atm.info['task_id']}-{atm.info['calc_id']}-{atm.info['ionic_step']}": {
"formula": str(atm.symbols)
}
| {key: atm.arrays.get(key) for key in ("forces", "magmoms")}
| atm.info
for atoms_list in mp_trj_atoms.values()
for atm in atoms_list
info_to_id(atoms.info): {"formula": str(atoms.symbols)}
| {key: atoms.arrays.get(key) for key in ("forces", "magmoms")}
| atoms.info
for atoms_list in tqdm(mp_trj_atoms.values(), total=len(mp_trj_atoms))
for atoms in atoms_list
}
).T.convert_dtypes() # convert object columns to float/int where possible
df_mp_trj.index.name = "frame_id"
Expand All @@ -86,11 +87,46 @@
df_mp_trj.to_json(f"{DATA_DIR}/mp/mp-trj-2022-09-summary.json.bz2")


# %% load MPtrj summary data
# %% --- load preprocessed MPtrj summary data ---
df_mp_trj = pd.read_json(f"{DATA_DIR}/mp/mp-trj-2022-09-summary.json.bz2")
df_mp_trj.index.name = "frame_id"


# %% plot per-element magmom histograms
magmom_hist_path = f"{DATA_DIR}/mp/mp-trj-2022-09-elem-magmoms.json.bz2"

if os.path.isfile(magmom_hist_path):
mp_trj_elem_magmoms = pd.read_json(magmom_hist_path, typ="series")
elif "mp_trj_elem_magmoms" not in locals():
df_mp_trj_magmom = pd.DataFrame(
{
info_to_id(atoms.info): (
dict(zip(atoms.symbols, atoms.arrays["magmoms"], strict=True))
if magmoms_col in atoms.arrays
else None
)
for frame_id in tqdm(mp_trj_atoms)
for atoms in mp_trj_atoms[frame_id]
}
).T.dropna(axis=0, how="all")

mp_trj_elem_magmoms = {
col: list(df_mp_trj_magmom[col].dropna()) for col in df_mp_trj_magmom
}
pd.Series(mp_trj_elem_magmoms).to_json(magmom_hist_path)

ax = ptable_hists(
mp_trj_elem_magmoms,
symbol_pos=(0.2, 0.8),
log=True,
cbar_title="Magmoms ($μ_B$)",
# annotate each element with its number of magmoms in MPtrj
anno_kwds=dict(text=lambda hist_vals: si_fmt(len(hist_vals), ".0f")),
)

save_fig(ax, f"{PDF_FIGS}/mp-trj-magmoms-ptable-hists.pdf")


# %%
elem_counts: dict[str, dict[str, int]] = {}
for count_mode in ("composition", "occurrence"):
Expand Down
5 changes: 3 additions & 2 deletions data/wbm/eda_wbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,10 @@
filename = f"{dataset}-element-counts-by-{count_mode}"
if log:
filename += "-log"
elem_counts.to_json(f"{data_page}/{filename}.json")
else:
elem_counts.to_json(f"{data_page}/{filename}.json")

title = "Number of MP structures containing each element"
title = f"Number of {dataset.upper()} structures containing each element"
fig = ptable_heatmap_plotly(elem_counts, font_size=10)
fig.layout.title.update(text=title, x=0.4, y=0.9)
fig.show()
Expand Down
5 changes: 5 additions & 0 deletions matbench_discovery/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,11 @@ def _on_not_found(self, key: str, msg: str) -> None: # type: ignore[override]
alignn_checkpoint = "2023-06-02-pbenner-best-alignn-model.pth.zip"
mace_checkpoint = "2023-08-14-mace-yuan-trained-mptrj-04.model"

mp_trj_extxyz = "mp/2023-11-22-mp-trj-extxyz-by-yuan.zip"

mace_checkpoint1 = "2023-08-14-mace-2M-yuan-mptrj-04.model"
mace_checkpoint2 = "2023-10-29-mace-16M-pbenner-mptrj-no-conditional-loss"


# data files can be downloaded and cached with matbench_discovery.data.load()
DATA_FILES = DataFiles()
Expand Down

0 comments on commit 13cbb90

Please sign in to comment.