Skip to content

Commit

Permalink
Merge pull request #2575 from rkingsbury/inputset_eq
Browse files Browse the repository at this point in the history
InputSet: implement equality method and fix __iter__
  • Loading branch information
mkhorton committed Jul 6, 2022
2 parents 723260b + ab9d4e5 commit 72d6e8b
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 4 deletions.
5 changes: 4 additions & 1 deletion pymatgen/io/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def __len__(self):
return len(self.inputs.keys())

def __iter__(self):
return iter(self.inputs.items())
return iter(self.inputs)

def __getitem__(self, key):
return self.inputs[key]
Expand All @@ -155,6 +155,9 @@ def __setitem__(self, key, value):
def __delitem__(self, key):
del self.inputs[key]

def __eq__(self, other):
return (self.inputs == other.inputs) and (self.__dict__ == other.__dict__)

def write_input(
self,
directory: Union[str, Path],
Expand Down
64 changes: 61 additions & 3 deletions pymatgen/io/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ def test_mapping(self):
with pytest.raises(AttributeError):
inp_set.kwarg3
expected = [("cif1", sif1), ("cif2", sif2), ("cif3", sif3)]
for (fname, contents), (exp_fname, exp_contents) in zip(inp_set, expected):

for (fname, contents), (exp_fname, exp_contents) in zip(inp_set.items(), expected):
assert fname == exp_fname
assert contents is exp_contents

Expand All @@ -104,10 +105,67 @@ def test_mapping(self):

assert len(inp_set) == 2
expected = [("cif1", sif1), ("cif3", sif3)]
for (fname, contents), (exp_fname, exp_contents) in zip(inp_set, expected):
for (fname, contents), (exp_fname, exp_contents) in zip(inp_set.items(), expected):
assert fname == exp_fname
assert contents is exp_contents

def test_equality(self):
sif1 = StructInputFile.from_file(os.path.join(test_dir, "Li.cif"))
sif2 = StructInputFile.from_file(os.path.join(test_dir, "LiFePO4.cif"))
sif3 = StructInputFile.from_file(os.path.join(test_dir, "Li2O.cif"))

inp_set = InputSet(
{
"cif1": sif1,
"cif2": sif2,
},
kwarg1=1,
kwarg2="hello",
)

inp_set2 = InputSet(
{
"cif1": sif1,
"cif2": sif2,
},
kwarg1=1,
kwarg2="hello",
)

inp_set3 = InputSet(
{
"cif1": sif1,
"cif2": sif2,
"cif3": sif3,
},
kwarg1=1,
kwarg2="hello",
)

inp_set4 = InputSet(
{
"cif1": sif1,
"cif2": sif2,
},
kwarg1=1,
kwarg2="goodbye",
)

inp_set5 = InputSet(
{
"cif1": sif1,
"cif2": sif2,
},
kwarg1=1,
kwarg2="hello",
kwarg3="goodbye",
)

assert inp_set == inp_set2
assert inp_set != inp_set3
assert inp_set != inp_set4
assert inp_set != inp_set5

def test_msonable(self):
sif1 = StructInputFile.from_file(os.path.join(test_dir, "Li.cif"))
sif2 = StructInputFile.from_file(os.path.join(test_dir, "Li2O.cif"))
Expand All @@ -127,7 +185,7 @@ def test_msonable(self):
assert temp_inp_set.kwarg1 == 1
assert temp_inp_set.kwarg2 == "hello"
assert temp_inp_set._kwargs == inp_set._kwargs
for (fname, contents), (fname2, contents2) in zip(temp_inp_set, inp_set):
for (fname, contents), (fname2, contents2) in zip(temp_inp_set.items(), inp_set.items()):
assert fname == fname2
assert contents.structure == contents2.structure

Expand Down

0 comments on commit 72d6e8b

Please sign in to comment.