Add upload_to_figshare.py and replace local data-loading with cached Figshare downloads (#13)

* use load_train_test() to load wbm-summary in data.py (closes #10)

* fetch_process_wbm_dataset.py: wrap urllib.request.urlretrieve in try/except (closes #12)

* add scripts/upload_to_figshare.py for publishing data files to figshare

* add data/figshare dir with readme and FIGSHARE in __init__.py

* change load_train_test() to load files from figshare instead of GitHub (closes #11)

* class Files: warn when accessing a file path that doesn't exist

* docs recommend --depth 1 for git clone

* fix tests/test_data.py

* add auto-generated data/figshare/1.0.0.json

* pyproject.toml: drop unused [tool.setuptools.package-data] matbench_discovery = ["data/mp/*.json"]

* fix AttributeError: 'DataFrame' object has no attribute 'material_id'
janosh committed Jun 20, 2023
1 parent d27f1c4 commit e75f5d8
Showing 15 changed files with 303 additions and 96 deletions.
10 changes: 10 additions & 0 deletions data/figshare/1.0.0.json
@@ -0,0 +1,10 @@
{
"mp_computed_structure_entries": "https://figshare.com/ndownloader/files/40344436",
"mp_elemental_ref_entries": "https://figshare.com/ndownloader/files/40344445",
"mp_energies": "https://figshare.com/ndownloader/files/40344448",
"mp_patched_phase_diagram": "https://figshare.com/ndownloader/files/40344451",
"wbm_computed_structure_entries": "https://figshare.com/ndownloader/files/40344463",
"wbm_initial_structures": "https://figshare.com/ndownloader/files/40344466",
"wbm_cses_plus_init_structs": "https://figshare.com/ndownloader/files/40344469",
"wbm_summary": "https://figshare.com/ndownloader/files/40344475"
}
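Each entry above maps a data-file key to a versioned Figshare download URL. A minimal sketch of consuming this map directly (run from the repo root; the cache path here is a simplified stand-in for the layout `load_train_test()` uses in `data.py` below):

```python
import json
import os
import urllib.request

with open("data/figshare/1.0.0.json") as json_file:
    file_urls = json.load(json_file)  # data key -> Figshare download URL

# simplified cache path; load_train_test() nests the repo-relative file path
# under ~/.cache/matbench-discovery/<version>/ instead
cache_path = os.path.expanduser("~/.cache/matbench-discovery/1.0.0/wbm_summary.csv")
os.makedirs(os.path.dirname(cache_path), exist_ok=True)
if not os.path.isfile(cache_path):
    urllib.request.urlretrieve(file_urls["wbm_summary"], cache_path)
```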
3 changes: 3 additions & 0 deletions data/figshare/readme.md
@@ -0,0 +1,3 @@
# Figshare File URLs

Files in this directory are auto-generated by `scripts/upload_to_figshare.py`.
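The script itself is not part of this diff. A hedged sketch of how such a version file could be generated from a published Figshare article — the article ID, the key derivation, and the use of `requests` are assumptions; only the output format is taken from this commit, and the public Figshare API v2 `articles` endpoint is used for file listing:

```python
import json

import requests

ARTICLE_ID = 0  # hypothetical: the real article ID is not shown in this diff


def file_key(name: str) -> str:
    """Illustrative key derivation: 2022-10-19-wbm-summary.csv -> wbm_summary."""
    stem = name.split(".")[0]  # drop file extension(s)
    return "_".join(stem.split("-")[3:])  # drop YYYY-MM-DD prefix, join with _


# Figshare API v2: list the files attached to a public article
response = requests.get(f"https://api.figshare.com/v2/articles/{ARTICLE_ID}")
response.raise_for_status()
file_urls = {
    file_key(file["name"]): file["download_url"] for file in response.json()["files"]
}

with open("data/figshare/1.0.0.json", "w") as json_file:
    json.dump(file_urls, json_file, indent=2)
```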
7 changes: 6 additions & 1 deletion data/wbm/fetch_process_wbm_dataset.py
@@ -181,7 +181,12 @@ def increment_wbm_material_id(wbm_id: str) -> str:
print(f"{file_path} already exists, skipping")
continue

urllib.request.urlretrieve(f"{mat_cloud_url}&{filename=}", file_path)
try:
url = f"{mat_cloud_url}&filename={filename}"
urllib.request.urlretrieve(url, file_path)
except urllib.error.HTTPError as exc:
print(f"failed to download {url=}: {exc}")
continue


# %%
1 change: 1 addition & 0 deletions matbench_discovery/__init__.py
@@ -8,6 +8,7 @@
FIGS = f"{ROOT}/site/src/figs" # directory to store interactive figures
STATIC = f"{ROOT}/site/static/figs" # directory to store static figures
MODELS = f"{ROOT}/site/src/routes/models" # directory to write model analysis
FIGSHARE = f"{ROOT}/data/figshare"

# whether a currently running slurm job is in debug mode
DEBUG = "DEBUG" in os.environ or (
87 changes: 58 additions & 29 deletions matbench_discovery/data.py
@@ -1,6 +1,8 @@
from __future__ import annotations

import json
import os
import sys
import urllib.error
from collections.abc import Callable, Sequence
from glob import glob
@@ -12,13 +14,13 @@
from pymatgen.entries.computed_entries import ComputedStructureEntry
from tqdm import tqdm

from matbench_discovery import ROOT

df_wbm = pd.read_csv(f"{ROOT}/data/wbm/2022-10-19-wbm-summary.csv")
df_wbm.index = df_wbm.material_id
from matbench_discovery import FIGSHARE, ROOT

# repo URL to raw files on GitHub
RAW_REPO_URL = "https://raw.githubusercontent.com/janosh/matbench-discovery"
RAW_REPO_URL = "https://github.com/janosh/matbench-discovery/raw"
figshare_versions = sorted(
x.split(os.path.sep)[-1].split(".json")[0] for x in glob(f"{FIGSHARE}/*.json")
)
# directory to cache downloaded data files
default_cache_dir = os.path.expanduser("~/.cache/matbench-discovery")

@@ -27,32 +29,47 @@ class Files(dict):  # type: ignore
"""Files instance inherits from dict so that .values(), items(), etc. are supported
but also allows accessing attributes by dot notation. E.g. FILES.wbm_summary instead
of FILES["wbm_summary"]. This enables tab completion in IDEs and auto-updating
attribute names across the code base when changing the name of an attribute. Every
subclass must set the _root attribute to a path that serves as the root directory
w.r.t. which all files will be turned into absolute paths. The _key_map attribute
attribute names across the code base when changing the key of a file. The root
argument sets the directory w.r.t. which all files are turned into absolute paths.
The optional key_map argument can be used to map attribute names to different keys
in the dict. Useful if you want keys like 'foo+bar' that are not valid Python
identifiers.
"""

def __init__(self) -> None:
def __init__(
self, root: str = default_cache_dir, key_map: dict[str, str] | None = None
) -> None:
"""Create a Files instance."""
key_map = getattr(self, "_key_map", {})
dct = {
key_map.get(key, key): f"{self._root}{file}" # type: ignore
self._not_found_msg: Callable[[str], str] | None = None
rel_paths = {
(key_map or {}).get(key, key): file
for key, file in type(self).__dict__.items()
if not key.startswith("_")
}
self.__dict__ = dct
super().__init__(dct)
abs_paths = {key: f"{root}/{file}" for key, file in rel_paths.items()}
self.__dict__ = abs_paths
super().__init__(abs_paths)

def __getattribute__(self, key: str) -> str:
"""Override __getattr__ to check if file corresponding to key exists."""
file_path = super().__getattribute__(key)
if key in self and not os.path.isfile(file_path):
msg = f"Warning: {file_path!r} associated with {key=} does not exist."
if self._not_found_msg:
msg += f"\n{self._not_found_msg(key)}"
print(msg, file=sys.stderr)
return file_path


class DataFiles(Files):
"""Data files provided by Matbench Discovery.
See https://janosh.github.io/matbench-discovery/contribute for data descriptions.
"""

_root = f"{ROOT}/data/"

_not_found_msg = (
lambda self, key: "You can download it with matbench_discovery." # type: ignore
f"data.load_train_test({key!r}) which will cache the file for future use."
)
mp_computed_structure_entries = (
"mp/2023-02-07-mp-computed-structure-entries.json.gz"
)
@@ -69,7 +86,9 @@ class DataFiles(Files):
wbm_summary = "wbm/2022-10-19-wbm-summary.csv"


DATA_FILES = DataFiles()
# DATA_FILES paths point at the repo's data/ dir; files fetched with
# matbench_discovery.data.load_train_test() are cached to
# ~/.cache/matbench-discovery/<version>/ instead
DATA_FILES = DataFiles(root=f"{ROOT}/data")


def as_dict_handler(obj: Any) -> dict[str, Any] | None:
@@ -85,8 +104,8 @@ def as_dict_handler(obj: Any) -> dict[str, Any] | None:


def load_train_test(
data_names: str | Sequence[str] = ("summary",),
version: str = "1.0.0",
data_names: str | Sequence[str],
version: str = figshare_versions[-1],
cache_dir: str | Path = default_cache_dir,
hydrate: bool = False,
**kwargs: Any,
@@ -96,14 +115,15 @@
JSON which will be cached locally to cache_dir for faster re-loading unless
cache_dir is set to None.
See matbench_discovery.data.DATA_FILES for recognized data keys. See
https://janosh.github.io/matbench-discovery/contribute for descriptions.
See matbench_discovery.data.DATA_FILES for recognized data keys. For descriptions,
see https://janosh.github.io/matbench-discovery/contribute#--direct-download.
Args:
data_names (str | list[str]): Which parts of the MP/WBM data to load.
Can be any subset of the above data names or 'all'. Defaults to ["summary"].
Can be any subset of set(DATA_FILES) or 'all'.
version (str, optional): Which version of the dataset to load. Defaults to
'1.0.0'. Can be any git tag, branch or commit hash.
latest version of data files published to Figshare. Pass any invalid version
to see valid options.
cache_dir (str, optional): Where to cache data files on local drive. Defaults to
'~/.cache/matbench-discovery'. Set to None to disable caching.
hydrate (bool, optional): Whether to hydrate pymatgen objects. If False,
@@ -120,6 +140,8 @@
pd.DataFrame: Single dataframe or dictionary of dfs if
multiple data were requested.
"""
if version not in figshare_versions:
raise ValueError(f"Unexpected {version=}. Must be one of {figshare_versions}.")
if data_names == "all":
data_names = list(DATA_FILES)
elif isinstance(data_names, str):
@@ -128,17 +150,21 @@
if missing := set(data_names) - set(DATA_FILES):
raise ValueError(f"{missing} must be subset of {set(DATA_FILES)}")

with open(f"{FIGSHARE}/{version}.json") as json_file:
file_urls = json.load(json_file)

dfs = {}
for key in data_names:
file = DataFiles.__dict__[key]
reader = pd.read_csv if file.endswith(".csv") else pd.read_json
csv_ext = (".csv", ".csv.gz", ".csv.bz2")
reader = pd.read_csv if file.endswith(csv_ext) else pd.read_json

cache_path = f"{cache_dir}/{version}/{file}"
if os.path.isfile(cache_path):
print(f"Loading {key!r} from cached file at {cache_path!r}")
df = reader(cache_path, **kwargs)
else:
url = f"{RAW_REPO_URL}/{version}/data/{file}"
url = file_urls[key]
print(f"Downloading {key!r} from {url}")
try:
df = reader(url)
@@ -147,13 +173,12 @@
if cache_dir and not os.path.isfile(cache_path):
os.makedirs(os.path.dirname(cache_path), exist_ok=True)
if ".csv" in file:
df.to_csv(cache_path)
df.to_csv(cache_path, index=False)
elif ".json" in file:
df.reset_index().to_json(
cache_path, default_handler=as_dict_handler
)
df.to_json(cache_path, default_handler=as_dict_handler)
else:
raise ValueError(f"Unexpected file type {file}")
print(f"Cached {key!r} to {cache_path!r}")

df = df.set_index("material_id")
if hydrate:
@@ -206,3 +231,7 @@ def glob_to_df(
sub_dfs[file] = df

return pd.concat(sub_dfs.values())


df_wbm = load_train_test("wbm_summary")
df_wbm["material_id"] = df_wbm.index
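Taken together, the reworked `data.py` can be exercised like this (a usage sketch; the first call downloads from Figshare, subsequent calls read the local cache):

```python
from matbench_discovery.data import DATA_FILES, load_train_test

# dot access with IDE tab completion; if the file doesn't exist locally yet,
# a warning with the load_train_test() hint is printed to stderr
print(DATA_FILES.wbm_summary)

# a single key returns one dataframe indexed by material_id
df_wbm = load_train_test("wbm_summary")

# several keys (or 'all') return a dict of dataframes keyed by data name
dfs = load_train_test(["mp_energies", "wbm_initial_structures"])
```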
10 changes: 5 additions & 5 deletions matbench_discovery/preds.py
@@ -32,9 +32,6 @@ class PredFiles(Files):
See https://janosh.github.io/matbench-discovery/contribute for data descriptions.
"""

_root = f"{ROOT}/models/"
_key_map = model_labels # remap model keys below to pretty plot labels (see Files)

# BOWSR optimizer coupled with original megnet
bowsr_megnet = "bowsr/2023-01-23-bowsr-megnet-wbm-IS2RE.csv"
# default CHGNet model from publication with 400,438 params
@@ -62,7 +59,8 @@
wrenformer = "wrenformer/2022-11-15-wrenformer-IS2RE-preds.csv"


PRED_FILES = PredFiles()
# model_labels remaps model keys to pretty plot labels (see Files)
PRED_FILES = PredFiles(root=f"{ROOT}/models", key_map=model_labels)


def load_df_wbm_with_preds(
@@ -87,7 +85,9 @@
pd.DataFrame: WBM summary dataframe with model predictions.
"""
if mismatch := ", ".join(set(models) - set(PRED_FILES)):
raise ValueError(f"Unknown models: {mismatch}")
raise ValueError(
f"Unknown models: {mismatch}, expected subset of {set(PRED_FILES)}"
)

dfs: dict[str, pd.DataFrame] = {}

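Since `PredFiles` now receives `model_labels` as its `key_map`, prediction files are looked up by pretty plot label rather than by the lowercase attribute names above. A small sketch ("CHGNet" appears in the scripts below; the full label set lives in `model_labels`, which this diff doesn't show):

```python
import pandas as pd

from matbench_discovery.preds import PRED_FILES

print(sorted(PRED_FILES))  # all model labels known to the package

# dict and attribute access are equivalent thanks to the Files base class
assert PRED_FILES["CHGNet"] == PRED_FILES.CHGNet
df_chgnet = pd.read_csv(PRED_FILES.CHGNet)
```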
4 changes: 2 additions & 2 deletions models/chgnet/analyze_chgnet.py
@@ -21,9 +21,9 @@


# %%
df_chgnet = pd.read_csv(PRED_FILES.__dict__["CHGNet"])
df_chgnet = pd.read_csv(PRED_FILES.CHGNet)
df_chgnet = df_chgnet.set_index(id_col).add_suffix("_2000")
df_chgnet_500 = pd.read_csv(PRED_FILES.__dict__["CHGNet"].replace("-06", "-04"))
df_chgnet_500 = pd.read_csv(PRED_FILES.CHGNet.replace("-06", "-04"))
df_chgnet_500 = df_chgnet_500.set_index(id_col).add_suffix("_500")
df_chgnet[list(df_chgnet_500)] = df_chgnet_500
df_chgnet["formula"] = df_wbm.formula
6 changes: 3 additions & 3 deletions models/chgnet/ctk_structure_viewer.py
@@ -18,14 +18,14 @@
e_form_2000 = "e_form_per_atom_chgnet_2000"
e_form_500 = "e_form_per_atom_chgnet_500"

df_chgnet = pd.read_json(PRED_FILES.__dict__["CHGNet"].replace(".csv", ".json.gz"))
df_chgnet = pd.read_json(PRED_FILES.CHGNet.replace(".csv", ".json.gz"))
df_chgnet = df_chgnet.set_index("material_id")

df_chgnet_2000 = pd.read_csv(PRED_FILES.__dict__["CHGNet"])
df_chgnet_2000 = pd.read_csv(PRED_FILES.CHGNet)
df_chgnet_2000 = df_chgnet_2000.set_index("material_id").add_suffix("_2000")
df_chgnet[list(df_chgnet_2000)] = df_chgnet_2000

df_chgnet_500 = pd.read_csv(PRED_FILES.__dict__["CHGNet"].replace("-06", "-04"))
df_chgnet_500 = pd.read_csv(PRED_FILES.CHGNet.replace("-06", "-04"))
df_chgnet_500 = df_chgnet_500.set_index("material_id").add_suffix("_500")
df_chgnet[list(df_chgnet_500)] = df_chgnet_500

4 changes: 2 additions & 2 deletions models/megnet/test_megnet.py
@@ -58,8 +58,8 @@
data_path = {
"IS2RE": DATA_FILES.wbm_initial_structures,
"RS2RE": DATA_FILES.wbm_computed_structure_entries,
"chgnet_structure": PRED_FILES.__dict__["CHGNet"].replace(".csv", ".json.gz"),
"m3gnet_structure": PRED_FILES.__dict__["M3GNet"].replace(".csv", ".json.gz"),
"chgnet_structure": PRED_FILES.CHGNet.replace(".csv", ".json.gz"),
"m3gnet_structure": PRED_FILES.M3GNet.replace(".csv", ".json.gz"),
}[task_type]
print(f"\nJob started running {timestamp}")
print(f"{data_path=}")
4 changes: 1 addition & 3 deletions pyproject.toml
@@ -58,9 +58,6 @@ running-models = ["aviary", "m3gnet", "maml", "megnet"]
[tool.setuptools.packages]
find = { include = ["matbench_discovery*"], exclude = ["tests*"] }

[tool.setuptools.package-data]
matbench_discovery = ["data/mp/*.json"]

[tool.distutils.bdist_wheel]
universal = true

@@ -111,3 +108,4 @@ no_implicit_optional = false
[tool.pytest.ini_options]
testpaths = ["tests"]
addopts = "-p no:warnings"
markers = ["slow: deselect slow tests with -m 'not slow'"]
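With the marker registered here, `pytest -m 'not slow'` deselects tests decorated with `@pytest.mark.slow`, as the marker's own description says.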
