Skip to content

Commit

Permalink
ruff unsafe fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
janosh committed Apr 3, 2024
1 parent f14d4f7 commit f55c6fa
Show file tree
Hide file tree
Showing 12 changed files with 43 additions and 55 deletions.
34 changes: 13 additions & 21 deletions custodian/cp2k/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,7 @@ def check(self, directory="./"):
# General catch for SCF not converged
# TODO: should not-static runs allow for some unconverged scf? Leads to issues in my experience
scf = out.data["scf_converged"] or [True]
if not scf[0]:
return True
return False
return bool(not scf[0])

def correct(self, directory="./"):
"""Apply corrections to aid convergence if possible."""
Expand Down Expand Up @@ -488,9 +486,7 @@ def check(self, directory="./"):
try:
out.ran_successfully()
# If job finished, then hung, don't need to wait very long to confirm frozen
if time.time() - st.st_mtime > 300:
return True
return False
return time.time() - st.st_mtime > 300
except ValueError:
pass

Expand Down Expand Up @@ -825,9 +821,7 @@ def check(self, directory="./"):
"""Check for stuck SCF convergence."""
conv = get_conv(os.path.join(directory, self.output_file))
counts = [len([*group]) for _k, group in itertools.groupby(conv)]
if any(cnt > self.max_same for cnt in counts):
return True
return False
return bool(any(cnt > self.max_same for cnt in counts))

def correct(self, directory="/."):
"""Correct issue if possible."""
Expand Down Expand Up @@ -974,9 +968,7 @@ def check(self, directory="./"):
"""Check for unconverged geometry optimization."""
o = Cp2kOutput(os.path.join(directory, self.output_file))
o.convergence()
if o.data.get("geo_opt_not_converged"):
return True
return False
return bool(o.data.get("geo_opt_not_converged"))

def correct(self, directory):
"""Correct issue if possible."""
Expand Down Expand Up @@ -1046,15 +1038,15 @@ def __init__(self, output_file="cp2k.out", enable_checkpointing=True):

def check(self, directory="./"):
"""Check if internal CP2K walltime handler was tripped."""
if regrep(
filename=os.path.join(directory, self.output_file),
patterns={"walltime": r"(exceeded requested execution time)"},
reverse=True,
terminate_on_match=True,
postprocess=bool,
).get("walltime"):
return True
return False
return bool(
regrep(
filename=os.path.join(directory, self.output_file),
patterns={"walltime": "(exceeded requested execution time)"},
reverse=True,
terminate_on_match=True,
postprocess=bool,
).get("walltime")
)

def correct(self, directory="./"):
"""Dump checkpoint info if requested."""
Expand Down
6 changes: 2 additions & 4 deletions custodian/cp2k/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,14 +164,12 @@ def can_use_ot(output, ci, minimum_band_gap=0.1):
minimum_band_gap (float): the minimum band gap for OT
"""
output.parse_dos()
if (
return bool(
not ci.check("FORCE_EVAL/DFT/SCF/OT")
and not ci.check("FORCE_EVAL/DFT/KPOINTS")
and output.band_gap
and output.band_gap > minimum_band_gap
):
return True
return False
)


def tail(filename, n=10):
Expand Down
6 changes: 4 additions & 2 deletions custodian/custodian.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
ErrorHandlers and Jobs.
"""

from __future__ import annotations

import datetime
import logging
import os
Expand Down Expand Up @@ -662,7 +664,7 @@ def _do_check(self, handlers, terminate_func=None):
raise MaxCorrectionsPerHandlerError(
msg, raises=True, max_errors_per_handler=handler.max_num_corrections, handler=handler
)
logger.warning(msg + " Correction not applied.")
logger.warning(f"{msg} Correction not applied.")
continue
if terminate_func is not None and handler.is_terminating:
logger.info("Terminating job")
Expand Down Expand Up @@ -760,7 +762,7 @@ class ErrorHandler(MSONable):
"actions":[])
"""

max_num_corrections = None
max_num_corrections: int | None = None
raise_on_max = False
"""
Whether corrections from this specific handler should be applied only a
Expand Down
6 changes: 3 additions & 3 deletions custodian/lobster/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def check(self, directory: str = "./") -> bool:
lobsterout = Lobsterout(os.path.join(directory, self.output_filename))
if lobsterout.charge_spilling[0] > self.charge_spilling_limit:
return True
if len(lobsterout.charge_spilling) > 1 and lobsterout.charge_spilling[1] > self.charge_spilling_limit:
return True
return False
return bool(
len(lobsterout.charge_spilling) > 1 and lobsterout.charge_spilling[1] > self.charge_spilling_limit
)
return False
4 changes: 1 addition & 3 deletions custodian/vasp/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1698,9 +1698,7 @@ def check(self, directory="./"):
"""Check for error."""
run_time = datetime.datetime.now() - self.start_time
total_secs = run_time.seconds + run_time.days * 3600 * 24
if total_secs > self.interval:
return True
return False
return total_secs > self.interval

def correct(self, directory="./"):
"""Perform corrections."""
Expand Down
2 changes: 1 addition & 1 deletion custodian/vasp/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -906,7 +906,7 @@ def postprocess(self, directory="./"):
"""Postprocessing includes renaming and gzipping where necessary."""
# Add suffix to all sub_dir/{items}

neb_dirs, neb_sub = self._get_neb_dirs(directory)
neb_dirs, _neb_sub = self._get_neb_dirs(directory)

for path in neb_dirs:
for file in VASP_NEB_OUTPUT_SUB_FILES:
Expand Down
4 changes: 1 addition & 3 deletions custodian/vasp/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,7 @@ def check(self, directory="./"):
outcar = load_outcar(os.path.join(directory, "OUTCAR"))
patterns = {"MDALGO": r"MDALGO\s+=\s+([\d]+)"}
outcar.read_pattern(patterns=patterns)
if outcar.data["MDALGO"] == [["3"]]:
return False
return True
return outcar.data["MDALGO"] != [["3"]]


class VaspAECCARValidator(Validator):
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@
@pytest.fixture(autouse=True)
def _patch_get_potential_energy(monkeypatch):
"""Monkeypatch the multiprocessing.cpu_count() function to always return 64."""
monkeypatch.setattr(multiprocessing, "cpu_count", lambda *args, **kwargs: 64)
monkeypatch.setattr(multiprocessing, "cpu_count", lambda: 64)
6 changes: 3 additions & 3 deletions tests/cp2k/test_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
cwd = os.getcwd()


def clean_dir(dir):
for file in glob(os.path.join(dir, "error.*.tar.gz")):
def clean_dir(folder):
for file in glob(os.path.join(folder, "error.*.tar.gz")):
os.remove(file)
for file in glob(os.path.join(dir, "custodian.chk.*.tar.gz")):
for file in glob(os.path.join(folder, "custodian.chk.*.tar.gz")):
os.remove(file)


Expand Down
20 changes: 10 additions & 10 deletions tests/qchem/test_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ def test_OptFF(self):
== QCInput.from_file(os.path.join(SCR_DIR, "test.qin")).as_dict()
)
with pytest.raises(StopIteration):
job.__next__()
next(job)


@skip_if_no_openbabel
Expand Down Expand Up @@ -345,7 +345,7 @@ def test_OptFF(self):
).as_dict()
assert next(job).as_dict() == expected_next
with pytest.raises(StopIteration):
job.__next__()
next(job)


@skip_if_no_openbabel
Expand Down Expand Up @@ -395,7 +395,7 @@ def test_OptFF(self):
== QCInput.from_file(os.path.join(SCR_DIR, "mol.qin")).as_dict()
)
with pytest.raises(StopIteration):
job.__next__()
next(job)


@skip_if_no_openbabel
Expand Down Expand Up @@ -507,7 +507,7 @@ def test_OptFF(self):
== QCInput.from_file(os.path.join(SCR_DIR, "mol.qin")).as_dict()
)
with pytest.raises(StopIteration):
job.__next__()
next(job)


@skip_if_no_openbabel
Expand Down Expand Up @@ -667,7 +667,7 @@ def test_OptFF(self):
== QCInput.from_file(os.path.join(SCR_DIR, "mol.qin")).as_dict()
)
with pytest.raises(StopIteration):
job.__next__()
next(job)


@skip_if_no_openbabel
Expand Down Expand Up @@ -915,7 +915,7 @@ def test_OptFF(self):
os.path.join(SCR_DIR, "mol.qin.freq_2"),
)
with pytest.raises(StopIteration):
job.__next__()
next(job)


@skip_if_no_openbabel
Expand Down Expand Up @@ -971,7 +971,7 @@ def test_OptFF(self):
shutil.copyfile(f"{SCR_DIR}/mol.qin", f"{SCR_DIR}/mol.qin.freq_0")

with pytest.raises(StopIteration):
job.__next__()
next(job)


@skip_if_no_openbabel
Expand Down Expand Up @@ -1022,7 +1022,7 @@ def test_OptFF(self):
== QCInput.from_file(f"{SCR_DIR}/test.qin").as_dict()
)
with pytest.raises(StopIteration):
job.__next__()
next(job)


@skip_if_no_openbabel
Expand Down Expand Up @@ -1094,7 +1094,7 @@ def test_OptFF(self):
)
shutil.copyfile(f"{SCR_DIR}/mol.qin", f"{SCR_DIR}/mol.qin.freq_0")
with pytest.raises(StopIteration):
job.__next__()
next(job)


@skip_if_no_openbabel
Expand Down Expand Up @@ -1201,4 +1201,4 @@ def test_OptFF(self):
shutil.copyfile(f"{SCR_DIR}/mol.qin", f"{SCR_DIR}/mol.qin.freq_1")

with pytest.raises(StopIteration):
job.__next__()
next(job)
4 changes: 2 additions & 2 deletions tests/test_custodian.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class ExampleHandler1b(ExampleHandler):
This handler always can apply a correction, but will only apply it twice before raising.
"""

max_num_corrections = 2 # type: ignore
max_num_corrections = 2
raise_on_max = True


Expand All @@ -84,7 +84,7 @@ class ExampleHandler1c(ExampleHandler):
This handler always can apply a correction, but will only apply it twice and then not anymore.
"""

max_num_corrections = 2 # type: ignore
max_num_corrections = 2
raise_on_max = False


Expand Down
4 changes: 2 additions & 2 deletions tests/vasp/test_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,7 +685,7 @@ def test_check_correct_electronic(self):
assert handler.check()
dct = handler.correct()
assert dct["errors"] == ["Unconverged"]
assert [{"dict": "INCAR", "action": {"_set": {"ALGO": "Damped", "TIME": 0.5}}}] == dct["actions"]
assert dct["actions"] == [{"dict": "INCAR", "action": {"_set": {"ALGO": "Damped", "TIME": 0.5}}}]

def test_check_correct_electronic_repeat(self):
shutil.copy("vasprun.xml.electronic2", "vasprun.xml")
Expand Down Expand Up @@ -714,7 +714,7 @@ def test_amin(self):
handler = UnconvergedErrorHandler()
assert handler.check()
dct = handler.correct()
assert [{"dict": "INCAR", "action": {"_set": {"AMIN": 0.01}}}] == dct["actions"]
assert dct["actions"] == [{"dict": "INCAR", "action": {"_set": {"AMIN": 0.01}}}]

def test_as_from_dict(self):
handler = UnconvergedErrorHandler("random_name.xml")
Expand Down

0 comments on commit f55c6fa

Please sign in to comment.