Skip to content

Commit

Permalink
load files according to filename suffix (deepmodeling#1255)
Browse files Browse the repository at this point in the history
The current `try...except` fallback does not report what is wrong in the
YAML file when it is invalid.

---------

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
njzjz and pre-commit-ci[bot] committed Jun 25, 2023
1 parent 52b40f9 commit 930f605
Show file tree
Hide file tree
Showing 10 changed files with 72 additions and 80 deletions.
20 changes: 3 additions & 17 deletions dpgen/data/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@

import argparse
import glob
import json
import os
import re
import shutil
import subprocess as sp
import sys
import warnings

import dpdata
import numpy as np
Expand All @@ -33,6 +31,7 @@
from dpgen.generator.lib.utils import symlink_user_forward_files
from dpgen.generator.lib.vasp import incar_upper
from dpgen.remote.decide_machine import convert_mdata
from dpgen.util import load_file


def create_path(path, back=False):
Expand Down Expand Up @@ -1465,22 +1464,9 @@ def run_abacus_md(jdata, mdata):


def gen_init_bulk(args):
try:
import ruamel
from monty.serialization import loadfn

warnings.simplefilter("ignore", ruamel.yaml.error.MantissaNoDotYAML1_1Warning)
jdata = loadfn(args.PARAM)
if args.MACHINE is not None:
mdata = loadfn(args.MACHINE)
except Exception:
with open(args.PARAM) as fp:
jdata = json.load(fp)
if args.MACHINE is not None:
with open(args.MACHINE) as fp:
mdata = json.load(fp)

jdata = load_file(args.PARAM)
if args.MACHINE is not None:
mdata = load_file(args.MACHINE)
# Selecting a proper machine
mdata = convert_mdata(mdata, ["fp"])
# disp = make_dispatcher(mdata["fp_machine"])
Expand Down
20 changes: 3 additions & 17 deletions dpgen/data/reaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,16 @@
"""

import glob
import json
import os
import random
import warnings

import dpdata

from dpgen import dlog
from dpgen.dispatcher.Dispatcher import make_submission_compat
from dpgen.generator.run import create_path, make_fp_task_name
from dpgen.remote.decide_machine import convert_mdata
from dpgen.util import normalize, sepline
from dpgen.util import load_file, normalize, sepline

from .arginfo import init_reaction_jdata_arginfo

Expand Down Expand Up @@ -214,20 +212,8 @@ def convert_data(jdata):


def gen_init_reaction(args):
try:
import ruamel
from monty.serialization import loadfn

warnings.simplefilter("ignore", ruamel.yaml.error.MantissaNoDotYAML1_1Warning)
jdata = loadfn(args.PARAM)
if args.MACHINE is not None:
mdata = loadfn(args.MACHINE)
except Exception:
with open(args.PARAM) as fp:
jdata = json.load(fp)
if args.MACHINE is not None:
with open(args.MACHINE) as fp:
mdata = json.load(fp)
jdata = load_file(args.PARAM)
mdata = load_file(args.MACHINE)

jdata_arginfo = init_reaction_jdata_arginfo()
jdata = normalize(jdata_arginfo, jdata)
Expand Down
19 changes: 3 additions & 16 deletions dpgen/data/surf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@

import argparse
import glob
import json
import os
import re
import shutil
import subprocess as sp
import sys
import warnings

import numpy as np
from ase.build import general_surface
Expand All @@ -29,6 +27,7 @@
from dpgen.dispatcher.Dispatcher import make_submission_compat
from dpgen.generator.lib.utils import symlink_user_forward_files
from dpgen.remote.decide_machine import convert_mdata
from dpgen.util import load_file


def create_path(path):
Expand Down Expand Up @@ -602,26 +601,14 @@ def run_vasp_relax(jdata, mdata):


def gen_init_surf(args):
try:
import ruamel
from monty.serialization import loadfn

warnings.simplefilter("ignore", ruamel.yaml.error.MantissaNoDotYAML1_1Warning)
jdata = loadfn(args.PARAM)
if args.MACHINE is not None:
mdata = loadfn(args.MACHINE)
except Exception:
with open(args.PARAM) as fp:
jdata = json.load(fp)
if args.MACHINE is not None:
with open(args.MACHINE) as fp:
mdata = json.load(fp)
jdata = load_file(args.PARAM)

out_dir = out_dir_name(jdata)
jdata["out_dir"] = out_dir
dlog.info("# working dir %s" % out_dir)

if args.MACHINE is not None:
mdata = load_file(args.MACHINE)
# Decide a proper machine
mdata = convert_mdata(mdata, ["fp"])
# disp = make_dispatcher(mdata["fp_machine"])
Expand Down
17 changes: 5 additions & 12 deletions dpgen/generator/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
from dpgen.util import (
convert_training_data_to_hdf5,
expand_sys_str,
load_file,
normalize,
sepline,
set_directory,
Expand Down Expand Up @@ -4494,25 +4495,17 @@ def set_version(mdata):


def run_iter(param_file, machine_file):
try:
import ruamel
from monty.serialization import dumpfn, loadfn

warnings.simplefilter("ignore", ruamel.yaml.error.MantissaNoDotYAML1_1Warning)
jdata = loadfn(param_file)
mdata = loadfn(machine_file)
except Exception:
with open(param_file) as fp:
jdata = json.load(fp)
with open(machine_file) as fp:
mdata = json.load(fp)
jdata = load_file(param_file)
mdata = load_file(machine_file)

jdata_arginfo = run_jdata_arginfo()
jdata = normalize(jdata_arginfo, jdata, strict_check=False)

update_mass_map(jdata)

if jdata.get("pretty_print", False):
from monty.serialization import dumpfn

# assert(jdata["pretty_format"] in ['json','yaml'])
fparam = (
SHORT_CMD
Expand Down
19 changes: 3 additions & 16 deletions dpgen/simplify/simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,9 @@
02: fp (optional, if the original dataset do not have fp data, same as generator)
"""
import glob
import json
import logging
import os
import queue
import warnings
from collections import defaultdict
from typing import List, Union

Expand Down Expand Up @@ -46,7 +44,7 @@
train_task_fmt,
)
from dpgen.remote.decide_machine import convert_mdata
from dpgen.util import expand_sys_str, normalize, sepline
from dpgen.util import expand_sys_str, load_file, normalize, sepline

from .arginfo import simplify_jdata_arginfo

Expand Down Expand Up @@ -433,19 +431,8 @@ def run_iter(param_file, machine_file):
07 run_fp (same as generator)
08 post_fp (same as generator)
"""
# TODO: function of handling input json should be combined as one function
try:
import ruamel
from monty.serialization import loadfn

warnings.simplefilter("ignore", ruamel.yaml.error.MantissaNoDotYAML1_1Warning)
jdata = loadfn(param_file)
mdata = loadfn(machine_file)
except Exception:
with open(param_file) as fp:
jdata = json.load(fp)
with open(machine_file) as fp:
mdata = json.load(fp)
jdata = load_file(param_file)
mdata = load_file(machine_file)

jdata_arginfo = simplify_jdata_arginfo()
jdata = normalize(jdata_arginfo, jdata)
Expand Down
33 changes: 33 additions & 0 deletions dpgen/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,3 +175,36 @@ def set_directory(path: Path):
yield
finally:
os.chdir(cwd)


def load_file(filename: Union[str, os.PathLike]) -> dict:
    """Load data from a JSON or YAML file.

    The parser is selected from the filename suffix alone; there is no
    fallback to another format, so a malformed file raises the error of the
    parser that was selected. The suffix comparison is case-insensitive.

    Parameters
    ----------
    filename : str or os.PathLike
        The filename to load data from, whose suffix should be .json, .yaml,
        or .yml (in any letter case)

    Returns
    -------
    dict
        The data loaded from the file

    Raises
    ------
    ValueError
        If the file format is not supported
    """
    filename = str(filename)
    # Match the suffix on a lowercased copy so that e.g. "param.JSON" is
    # accepted, while the file is still opened under its original name.
    lowered = filename.lower()
    if lowered.endswith(".json"):
        with open(filename) as fp:
            data = json.load(fp)
    elif lowered.endswith((".yaml", ".yml")):
        # Imported inside the branch so that loading JSON does not import
        # ruamel.yaml.
        from ruamel.yaml import YAML

        yaml = YAML(typ="safe", pure=True)
        with open(filename) as fp:
            data = yaml.load(fp)
    else:
        raise ValueError(f"Unsupported file format: {filename}")
    return data
1 change: 1 addition & 0 deletions tests/sample.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"aa": "bb"}
1 change: 1 addition & 0 deletions tests/sample.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
aa: bb
18 changes: 18 additions & 0 deletions tests/test_load_file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import unittest
from pathlib import Path

from dpgen.util import load_file

# Directory containing this test module; fixture paths are built relative to it.
this_directory = Path(__file__).parent


class TestLoadFile(unittest.TestCase):
    """Check that ``load_file`` parses the JSON and YAML sample fixtures."""

    def test_load_json_file(self):
        # The JSON fixture must load to the reference mapping.
        loaded = load_file(this_directory / "sample.json")
        self.assertEqual(loaded, {"aa": "bb"})

    def test_load_yaml_file(self):
        # The YAML fixture must load to the same reference mapping.
        loaded = load_file(this_directory / "sample.yaml")
        self.assertEqual(loaded, {"aa": "bb"})
4 changes: 2 additions & 2 deletions tests/tools/test_convert_mdata.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
import os
import sys
import unittest
Expand All @@ -7,6 +6,7 @@
sys.path.insert(0, os.path.join(test_dir, ".."))
__package__ = "tools"
from dpgen.remote.decide_machine import convert_mdata
from dpgen.util import load_file

from .context import setUpModule # noqa: F401

Expand All @@ -15,7 +15,7 @@ class TestConvertMdata(unittest.TestCase):
machine_file = "machine_fp_single.json"

def test_convert_mdata(self):
mdata = json.load(open(self.machine_file))
mdata = load_file(self.machine_file)
mdata = convert_mdata(mdata, ["fp"])
self.assertEqual(mdata["fp_command"], "vasp_std")
self.assertEqual(mdata["fp_group_size"], 8)
Expand Down

0 comments on commit 930f605

Please sign in to comment.