Skip to content

Commit

Permalink
Merge pull request #82 from muon-spectroscopy-computational-project/7…
Browse files Browse the repository at this point in the history
…3_mutable_default_arguments

Replace instances of mutable default arguments #73
  • Loading branch information
patrick-austin committed Jun 9, 2023
2 parents e35c8b2 + 937eb03 commit 00d00fe
Show file tree
Hide file tree
Showing 7 changed files with 32 additions and 28 deletions.
7 changes: 4 additions & 3 deletions pymuonsuite/dipolar/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,12 @@ def __init__(
self,
atoms,
mu_pos,
isotopes={},
isotopes=None,
isotope_list=None,
cutoff=10,
overlap_eps=1e-3,
):
isotopes = {} if isotopes is None else isotopes

# Get positions, cell, and species, only things we care about
self.cell = np.array(atoms.get_cell())
Expand Down Expand Up @@ -147,7 +148,7 @@ def set_moments(self, moments, moment_type="e"):
def dipten(self):
return np.sum(self.spins[:, None, None] * self._dT, axis=0)

def frequency(self, axis=[0, 0, 1]):
def frequency(self, axis=(0, 0, 1)):

D = self.dipten()
return np.sum(np.dot(D, axis) * axis)
Expand All @@ -173,7 +174,7 @@ def pwd_spec(self, width=None, h_steps=100, nsteps=100):

return om, spec

def random_spec_uniaxial(self, axis=[0, 0, 1], width=None, h_steps=100, occ=1.0):
def random_spec_uniaxial(self, axis=(0, 0, 1), width=None, h_steps=100, occ=1.0):

# Consider individual dipolar constants
DD = self.spins[:, None, None] * self._dT
Expand Down
9 changes: 1 addition & 8 deletions pymuonsuite/io/castep.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@


class ReadWriteCastep(ReadWrite):
def __init__(self, params={}, script=None, calc=None):
def __init__(self, params=None, script=None, calc=None):
"""
| params (dict) Contains muon symbol, parameter file,
| k_points_grid.
Expand All @@ -40,13 +40,6 @@ def __init__(self, params={}, script=None, calc=None):
if calc is not None and params != {}:
self._create_calculator()

def _validate_params(self, params):
if not (isinstance(params, dict)):
raise ValueError("params should be a dict, not ", type(params))
return
else:
return params

def set_params(self, params):
"""
| params (dict) Contains muon symbol, parameter file,
Expand Down
8 changes: 2 additions & 6 deletions pymuonsuite/io/dftb.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@


class ReadWriteDFTB(ReadWrite):
def __init__(self, params={}, script=None, calc=None):
def __init__(self, params=None, script=None, calc=None):
"""
| Args:
| params (dict): Contains dftb_set, k_points_grid,
Expand Down Expand Up @@ -70,11 +70,7 @@ def set_params(self, params):
| charged are also required in the case
| of writing geom_opt input files
"""
if not (isinstance(params, dict)):
raise ValueError("params should be a dict, not ", type(params))
return

self.params = deepcopy(params)
self.params = deepcopy(self._validate_params(params))
# resetting this to None makes sure that the calc is recreated after
# the params are updated:
self._calc_type = None
Expand Down
26 changes: 19 additions & 7 deletions pymuonsuite/io/readwrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


class ReadWrite(object):
def __init__(self, params={}, script=None, calc=None):
def __init__(self, params=None, script=None, calc=None):
"""
| params (dict) parameters for writing input files
| script (str): Path to a file containing a submission
Expand All @@ -19,7 +19,23 @@ def __init__(self, params={}, script=None, calc=None):
"""
self._calc = calc
self.script = script
self.params = params
self.params = self._validate_params(params)

def _validate_params(self, params: dict) -> dict:
"""
| Args:
| params (dict): dict of parameters to validate
| Returns:
| (dict): params, or an empty dict if they were None
| Raises:
| TypeError: params is neither None nor a dict
"""
if params is None:
return {}
elif isinstance(params, dict):
return params
else:
raise TypeError(f"params should be a dict, not {type(params)}")

def read(self, folder, sname=None):
raise (
Expand All @@ -40,11 +56,7 @@ def set_params(self, params):
| params (dict) Contains muon symbol, parameter file,
| k_points_grid.
"""
if not (isinstance(params, dict)):
raise ValueError("params should be a dict, not ", type(params))
return
else:
self.params = params
self.params = self._validate_params(params)

def set_script(self, script):
"""
Expand Down
2 changes: 1 addition & 1 deletion pymuonsuite/io/uep.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


class ReadWriteUEP(ReadWrite):
def __init__(self, params={}, script=None):
def __init__(self, params=None, script=None):
self.set_script(script)
self.set_params(params)

Expand Down
4 changes: 2 additions & 2 deletions pymuonsuite/quantum/vibrational/phonons.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
def ase_phonon_calc(
struct,
calc=None,
kpoints=[1, 1, 1],
kpoints=(1, 1, 1),
ftol=0.01,
force_clean=False,
name="asephonon",
Expand All @@ -36,7 +36,7 @@ def ase_phonon_calc(
| calc (ase.Calculator): Calculator for energies and forces (if not
| present, use the one from struct)
| kpoints (np.ndarray): Kpoint grid for phonon calculation. If None, just
| do a Vibration modes calculation (default is [1,1,1])
| do a Vibration modes calculation (default is (1,1,1))
| ftol (float): Tolerance for geometry optimisation (default
| is 0.01 eV/Ang)
| force_clean (bool): If True, force a deletion of all phonon files
Expand Down
4 changes: 3 additions & 1 deletion pymuonsuite/quantum/vibrational/schemes.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,9 @@ def save(self, file):
def load(file):
return pickle.load(open(file))

def recalc_all(self, displ_args={}, weights_args={}):
def recalc_all(self, displ_args: dict = None, weights_args: dict = None):
displ_args = {} if displ_args is None else displ_args
weights_args = {} if weights_args is None else weights_args
self.recalc_displacements(**displ_args)
self.recalc_weights(**weights_args)

Expand Down

0 comments on commit 00d00fe

Please sign in to comment.