Skip to content

Commit

Permalink
tweak WBM formation energy and convex hull distance histograms
Browse files Browse the repository at this point in the history
add figshare URL for mp_trj_extxyz_by_yuan
lower file size of site/src/figs/hist-wbm-e-form-per-atom.svelte using px.bar instead of px.histogram
move global plot settings from plots.py to __init__.py
  • Loading branch information
janosh committed Nov 27, 2023
1 parent fea24e6 commit 1098aa6
Show file tree
Hide file tree
Showing 11 changed files with 165 additions and 145 deletions.
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ data/**/tsne
!data/mp/2023-02-07-mp-elemental-reference-entries.json.gz
models/**/checkpoints
data/**/*.json*
data/**/*.gz*
data/**/*.zip*

# slurm + Weights and Biases logs
Expand Down
4 changes: 4 additions & 0 deletions data/figshare/1.0.0.json
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@
"wbm_summary": [
"https://figshare.com/ndownloader/files/41296866",
"2022-10-19-wbm-summary.csv.gz"
],
"mp_trj_extxyz_by_yuan": [
"https://figshare.com/ndownloader/files/43302033",
"2023-11-22-mp-trj-extxyz-by-yuan.zip"
]
},
"article": "https://figshare.com/articles/dataset/22715158",
Expand Down
16 changes: 9 additions & 7 deletions data/mp/eda_mp_trj.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
# %% downloaded mptrj-gga-ggapu.tar.gz from https://drive.google.com/drive/folders/1JQ-ry1RHvNliVg1Ut5OuyUxne51RHiT_
# and extracted the mptrj-gga-ggapu directory (6.2 GB) to data/mp using macOS Finder
# then zipped it to mp-trj-extxyz.zip (also using Finder, 1.6 GB)
zip_path = f"{DATA_DIR}/mp/mp-trj-extxyz-by-yuan.zip"
zip_path = f"{DATA_DIR}/mp/2023-11-22-mp-trj-extxyz-by-yuan.zip"
mp_trj_atoms: dict[str, list[ase.Atoms]] = {}

# extract extXYZ files from zipped directory without unpacking the whole archive
Expand All @@ -56,7 +56,7 @@
mp_trj_atoms[mp_id] = atoms_list


assert len(mp_trj_atoms) == 145_919
assert len(mp_trj_atoms) == 145_919 # number of unique MP IDs


# %%
Expand All @@ -72,7 +72,7 @@
}
).T.convert_dtypes() # convert object columns to float/int where possible
df_mp_trj.index.name = "frame_id"
assert len(df_mp_trj) == 1_580_312
assert len(df_mp_trj) == 1_580_312 # number of total frames
assert formula_col in df_mp_trj

# this is the unrelaxed (but MP2020 corrected) formation energy per atom of the actual
Expand Down Expand Up @@ -108,16 +108,18 @@
f"{data_page}/mp-trj-element-counts-by-occurrence.json", typ="series"
)

excl_elems = "He Ne Ar Kr Xe".split() if (excl_noble := True) else ()
excl_elems = "He Ne Ar Kr Xe".split() if (excl_noble := False) else ()

ax_ptable = ptable_heatmap( # matplotlib version looks better for SI
trj_elem_counts,
fmt=lambda x, _: si_fmt(x, ".1f"),
fmt=lambda x, _: si_fmt(x, ".0f"),
cbar_fmt=lambda x, _: si_fmt(x, ".0f"),
zero_color="#efefef",
log=(log := True),
# drop noble gases
exclude_elements=excl_elems,
exclude_elements=excl_elems, # drop noble gases
cbar_range=None if excl_noble else (2000, None),
label_font_size=17,
value_font_size=14,
)

img_name = f"mp-trj-element-counts-by-occurrence{'-log' if log else ''}"
Expand Down
45 changes: 30 additions & 15 deletions data/wbm/eda_wbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,11 @@


# %%
log = True
for dataset, count_mode, elem_counts in all_counts:
filename = f"{dataset}-element-counts-by-{count_mode}"
if log:
filename += "-log"
elem_counts.to_json(f"{data_page}/{filename}.json")

title = "Number of MP structures containing each element"
Expand All @@ -75,9 +78,14 @@

ax_mp_cnt = ptable_heatmap( # matplotlib version looks better for SI
elem_counts,
fmt=lambda x, _: si_fmt(x, ".1f"),
fmt=lambda x, _: si_fmt(x, ".0f"),
cbar_fmt=lambda x, _: si_fmt(x, ".0f"),
zero_color="#efefef",
label_font_size=17,
value_font_size=14,
cbar_title=f"{dataset.upper()} Element Count",
log=log,
cbar_range=(100, None),
)
save_fig(ax_mp_cnt, f"{PDF_FIGS}/{filename}.pdf")

Expand Down Expand Up @@ -133,31 +141,37 @@


# %% histogram of energy distance to MP convex hull for WBM
col = each_true_col # or e_form_col
mean, std = df_wbm[col].mean(), df_wbm[col].std()
# e_col = each_true_col # or e_form_col
e_col = "e_form_per_atom_uncorrected"
e_col = "e_form_per_atom_mp2020_corrected"
mean, std = df_wbm[e_col].mean(), df_wbm[e_col].std()

range_x = (mean - 2 * std, mean + 2 * std)
counts, bins = np.histogram(df_wbm[col], bins=150, range=range_x)
bins = bins[1:]
counts, bins = np.histogram(df_wbm[e_col], bins=150, range=range_x)
bins = bins[1:] # remove left-most bin edge
left_counts = counts[bins < 0]
right_counts = counts[bins >= 0]

x_label = "WBM energy above MP convex hull (eV/atom)"
assert e_col.startswith(("e_form_per_atom", "e_above_hull"))
x_label = "energy above MP convex hull" if "above" in e_col else "formation energy"
y_label = "Number of Structures"
fig = px.bar(
x=bins[bins < 0], y=left_counts, opacity=0.7, labels={"x": x_label, "y": "Counts"}
x=bins[bins < 0],
y=left_counts,
labels={"x": f"WBM {x_label} (eV/atom)", "y": y_label},
)
fig.add_bar(x=bins[bins >= 0], y=right_counts, opacity=0.7)
fig.add_bar(x=bins[bins >= 0], y=right_counts)
fig.update_traces(width=(bins[1] - bins[0])) # make bars touch

if col.startswith("e_above_hull"):
n_stable = sum(df_wbm[col] <= STABILITY_THRESHOLD)
n_unstable = sum(df_wbm[col] > STABILITY_THRESHOLD)
if e_col.startswith("e_above_hull"):
n_stable = sum(df_wbm[e_col] <= STABILITY_THRESHOLD)
n_unstable = sum(df_wbm[e_col] > STABILITY_THRESHOLD)
assert n_stable + n_unstable == len(df_wbm.dropna())

dummy_mae = (df_wbm[col] - df_wbm[col].mean()).abs().mean()
dummy_mae = (df_wbm[e_col] - df_wbm[e_col].mean()).abs().mean()

title = (
f"n={len(df_wbm.dropna()):,} with {n_stable:,} stable + {n_unstable:,} "
f"unstable, dummy MAE={dummy_mae:.2f}"
f"{len(df_wbm.dropna()):,} structures with {n_stable:,} stable + {n_unstable:,}"
)
fig.layout.title = dict(text=title, x=0.5)

Expand All @@ -170,7 +184,8 @@
(mean + std, f"{mean + std = :.2f}"),
):
anno = dict(text=label, yshift=-10, xshift=-5, xanchor="right")
fig.add_vline(x=x_pos, line=dict(width=1, dash="dash"), annotation=anno)
line_width = 1 if x_pos == mean else 0.5
fig.add_vline(x=x_pos, line=dict(width=line_width, dash="dash"), annotation=anno)

fig.show()

Expand Down
48 changes: 21 additions & 27 deletions data/wbm/fetch_process_wbm_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,11 @@
from pymatgen.entries.computed_entries import ComputedStructureEntry
from pymatviz import density_scatter
from pymatviz.io import save_fig
from pymatviz.utils import patch_dict
from tqdm import tqdm

from matbench_discovery import PDF_FIGS, SITE_FIGS, formula_col, id_col, today
from matbench_discovery.data import DATA_FILES
from matbench_discovery.energy import get_e_form_per_atom
from matbench_discovery.plots import pio

try:
import gdown
Expand All @@ -41,8 +39,7 @@


module_dir = os.path.dirname(__file__)

assert pio.templates.default == "plotly_dark+global"
e_form_col = "e_form_per_atom_wbm"


# %% links to google drive files received via email from 1st author Hai-Chen Wang
Expand Down Expand Up @@ -295,7 +292,7 @@ def increment_wbm_material_id(wbm_id: str) -> str:
"nsites": "n_sites",
"vol": "volume",
"e": "uncorrected_energy",
"e_form": "e_form_per_atom_wbm",
"e_form": e_form_col,
"e_hull": "e_above_hull_wbm",
"gap": "bandgap_pbe",
"id": id_col,
Expand Down Expand Up @@ -440,22 +437,28 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:

# %% remove suspicious formation energy outliers
e_form_cutoff = 5
n_too_stable = sum(df_summary.e_form_per_atom_wbm < -e_form_cutoff)
n_too_stable = sum(df_summary[e_form_col] < -e_form_cutoff)
print(f"{n_too_stable = }") # n_too_stable = 502
n_too_unstable = sum(df_summary.e_form_per_atom_wbm > e_form_cutoff)
n_too_unstable = sum(df_summary[e_form_col] > e_form_cutoff)
print(f"{n_too_unstable = }") # n_too_unstable = 22

fig = px.histogram(df_summary, x="e_form_per_atom_wbm", log_y=True, range_x=[-5.5, 5.5])
fig_compressed = False
e_form_hist, e_form_bins = np.histogram(
df_summary[e_form_col], bins=300, range=(-5.5, 5.5)
)
x_label = {e_form_col: "WBM uncorrected formation energy (eV/atom)"}[e_form_col]
fig = px.bar(
x=e_form_bins[:-1], # [:-1] to drop last bin edge which is not needed
y=e_form_hist,
log_y=True,
labels=dict(x=x_label, y="Number of Structures"),
)
fig.update_traces(width=(e_form_bins[1] - e_form_bins[0]), marker_line_width=0)
fig.add_vline(x=e_form_cutoff, line=dict(dash="dash"))
fig.add_vline(x=-e_form_cutoff, line=dict(dash="dash"))
fig.add_annotation(
**dict(x=0, y=1, yref="paper", yshift=20),
text=f"<b>dataset cropped to within +/- {e_form_cutoff} eV/atom</b>",
showarrow=False,
fig.layout.title = dict(
text=f"dataset cropped to within +/- {e_form_cutoff} eV/atom", x=0.5
)
x_axis_title = "WBM uncorrected formation energy (eV/atom)"
fig.update_layout(xaxis_title=x_axis_title, margin=dict(l=10, r=10, t=40, b=10))
fig.layout.margin = dict(l=0, r=0, b=0, t=40)
fig.update_yaxes(fixedrange=True) # disable zooming y-axis
fig.show(
config=dict(
Expand All @@ -466,28 +469,19 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:


# %%
# no need to store all 250k x values in plot, leads to 1.7 MB file, subsample every 10th
# point is enough to see the distribution, round to 3 decimal places to reduce file size
if not fig_compressed:
fig_compressed = True
fig.data[0].x = [round(x, 3) for x in fig.data[0].x[::10]]

img_name = "hist-wbm-e-form-per-atom"
save_fig(fig, f"{SITE_FIGS}/{img_name}.svelte")
# recommended to upload SVG to vecta.io/nano for compression
# save_fig(fig, f"{img_name}.svg", width=800, height=300)

# ensure full data range is visible in PDF (since can't zoom out)
fig.update_layout(xaxis_range=[-12, 82])
# remove title in PDF
with patch_dict(fig.layout, title=""):
save_fig(fig, f"{PDF_FIGS}/{img_name}.pdf")
# fig.update_layout(xaxis_range=[-12, 82]) # if full data range should be visible in PDF
save_fig(fig, f"{PDF_FIGS}/{img_name}.pdf")


# %%
assert len(df_summary) == len(df_wbm) == 257_487

query_str = f"{-e_form_cutoff} < e_form_per_atom_wbm < {e_form_cutoff}"
query_str = f"{-e_form_cutoff} < {e_form_col} < {e_form_cutoff}"
dropped_ids = sorted(set(df_summary.index) - set(df_summary.query(query_str).index))
assert len(dropped_ids) == 502 + 22
assert dropped_ids[:3] == "wbm-1-12142 wbm-1-12143 wbm-1-12144".split()
Expand Down
86 changes: 86 additions & 0 deletions matbench_discovery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@
import warnings
from datetime import datetime

import matplotlib.pyplot as plt
import plotly.express as px
import plotly.io as pio
from pymatviz.utils import styled_html_tag

ROOT = os.path.dirname(os.path.dirname(__file__)) # repo root directory
DATA_DIR = f"{ROOT}/data" # directory to store raw data
SITE_FIGS = f"{ROOT}/site/src/figs" # directory for interactive figures
Expand Down Expand Up @@ -47,3 +52,84 @@
# load figshare 1.0.0
with open(f"{FIGSHARE}/1.0.0.json") as file:
FIGSHARE_URLS = json.load(file)


# --- start global plot settings

ev_per_atom = styled_html_tag(
"(eV/atom)", tag="span", style="font-size: 0.8em; font-weight: lighter;"
)
quantity_labels = dict(
n_atoms="Atom Count",
n_elems="Element Count",
crystal_sys="Crystal system",
spg_num="Space group",
n_wyckoff="Number of Wyckoff positions",
n_sites="Lattice site count",
energy_per_atom=f"Energy {ev_per_atom}",
e_form=f"DFT E<sub>form</sub> {ev_per_atom}",
e_above_hull=f"E<sub>hull dist</sub> {ev_per_atom}",
e_above_hull_mp2020_corrected_ppd_mp=f"DFT E<sub>hull dist</sub> {ev_per_atom}",
e_above_hull_pred=f"Predicted E<sub>hull dist</sub> {ev_per_atom}",
e_above_hull_mp=f"E<sub>above MP hull</sub> {ev_per_atom}",
e_above_hull_error=f"Error in E<sub>hull dist</sub> {ev_per_atom}",
vol_diff="Volume difference (A^3)",
e_form_per_atom_mp2020_corrected=f"DFT E<sub>form</sub> {ev_per_atom}",
e_form_per_atom_pred=f"Predicted E<sub>form</sub> {ev_per_atom}",
material_id="Material ID",
band_gap="Band gap (eV)",
formula="Formula",
stress="σ (eV/ų)", # noqa: RUF001
stress_trace="1/3 Tr(σ) (eV/ų)", # noqa: RUF001
)
model_labels = dict(
alignn="ALIGNN",
alignn_ff="ALIGNN FF",
alignn_pretrained="ALIGNN Pretrained",
bowsr_megnet="BOWSR",
chgnet="CHGNet",
chgnet_megnet="CHGNet→MEGNet",
cgcnn_p="CGCNN+P",
cgcnn="CGCNN",
m3gnet_megnet="M3GNet→MEGNet",
m3gnet="M3GNet",
m3gnet_direct="M3GNet DIRECT",
m3gnet_ms="M3GNet MS",
mace="MACE",
megnet="MEGNet",
megnet_rs2re="MEGNet RS2RE",
voronoi_rf="Voronoi RF",
wrenformer="Wrenformer",
pfp="PFP",
dft="DFT",
wbm="WBM",
)
px.defaults.labels = quantity_labels | model_labels


global_layout = dict(
# colorway=px.colors.qualitative.Pastel,
# colorway=colorway,
margin=dict(l=30, r=20, t=60, b=20),
paper_bgcolor="rgba(0,0,0,0)",
# plot_bgcolor="rgba(0,0,0,0)",
font_size=13,
# increase legend marker size and make background transparent
legend=dict(itemsizing="constant", bgcolor="rgba(0, 0, 0, 0)"),
)
pio.templates["global"] = dict(layout=global_layout)
pio.templates.default = "plotly_dark+global"
px.defaults.template = "plotly_dark+global"

# https://github.com/plotly/Kaleido/issues/122#issuecomment-994906924
# when seeing MathJax "loading" message in exported PDFs, try:
# pio.kaleido.scope.mathjax = None


plt.rc("font", size=14)
plt.rc("legend", fontsize=16, title_fontsize=16)
plt.rc("axes", titlesize=16, labelsize=16)
plt.rc("savefig", bbox="tight", dpi=200)
plt.rc("figure", dpi=200, titlesize=16)
plt.rcParams["figure.constrained_layout.use"] = True
# --- end global plot settings
Loading

0 comments on commit 1098aa6

Please sign in to comment.