Skip to content

Commit

Permalink
Merge pull request #2517 from davidwaroquiers/input_sets_strcast
Browse files Browse the repository at this point in the history
Fixed MSONable + InputSet's str(object)
  • Loading branch information
shyuep committed May 10, 2022
2 parents 2b75957 + 79e1a87 commit 8a78794
Show file tree
Hide file tree
Showing 2 changed files with 185 additions and 6 deletions.
17 changes: 11 additions & 6 deletions pymatgen/io/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,17 +120,22 @@ def __init__(self, inputs: Dict[Union[str, Path], Union[str, InputFile]] = {}, *
Args:
inputs: The core mapping of filename: file contents that defines the InputSet data.
This should be a dict where keys are filenames and values are InputFile objects
or strings representing the entire contents of the file. This mapping will
or strings representing the entire contents of the file. If a value is not an
InputFile object nor a str, but has a __str__ method, this str representation
of the object will be written to the corresponding file. This mapping will
become the .inputs attribute of the InputSet.
**kwargs: Any kwargs passed will be set as class attributes e.g.
InputSet(inputs={}, foo='bar') will make InputSet.foo == 'bar'.
"""
self.inputs = inputs
self._kwargs = kwargs
self.__dict__.update(**kwargs)

def __getattr__(self, k):
# allow accessing keys as attributes
return self.get(k)
if k in self._kwargs:
return self.get(k)
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{k}'")

def __len__(self):
return len(self.inputs.keys())
Expand Down Expand Up @@ -179,11 +184,11 @@ def write_input(
file.touch()

# write the file
if isinstance(contents, str):
with zopen(file, "wt") as f:
f.write(contents)
else:
if isinstance(contents, InputFile):
contents.write_file(file)
else:
with zopen(file, "wt") as f:
f.write(str(contents))

if zip_inputs:
zipfilename = path / f"{type(self).__name__}.zip"
Expand Down
174 changes: 174 additions & 0 deletions pymatgen/io/tests/test_core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
# Copyright (c) Pymatgen Development Team.
# Distributed under the terms of the MIT License.

import os

import pytest
from monty.serialization import MontyDecoder
from monty.tempfile import ScratchDir

from pymatgen.core.structure import Structure
from pymatgen.io.cif import CifParser, CifWriter
from pymatgen.io.core import InputFile, InputSet
from pymatgen.util.testing import PymatgenTest

test_dir = os.path.join(PymatgenTest.TEST_FILES_DIR)


class StructInputFile(InputFile):
"""Test implementation of an InputFile object for CIF."""

def __init__(self, structure):
self.structure = structure

def get_string(self) -> str:
cw = CifWriter(self.structure)
return cw.__str__()

@classmethod
def from_string(cls, contents: str):
cp = CifParser.from_string(contents)
struct = cp.get_structures()[0]
return cls(structure=struct)


class FakeClass:
def __init__(self, a, b):
self.a = a
self.b = b

def write_file(self):
raise ValueError

def __str__(self):
return f"{self.a}\n{self.b}"


class TestInputFile:
def test_file_io(self):
with pytest.raises(FileNotFoundError):
StructInputFile.from_file("fakepath.cif")

sif = StructInputFile.from_file(os.path.join(test_dir, "Li.cif"))
assert isinstance(sif.structure, Structure)

with ScratchDir("."):
sif.write_file("newLi.cif")
assert os.path.exists("newLi.cif")

def test_msonable(self):
sif = StructInputFile.from_file(os.path.join(test_dir, "Li.cif"))
sif_dict = sif.as_dict()
decoder = MontyDecoder()
temp_sif = decoder.process_decoded(sif_dict)
assert isinstance(temp_sif, StructInputFile)
assert sif.structure == temp_sif.structure


class TestInputSet:
def test_mapping(self):
sif1 = StructInputFile.from_file(os.path.join(test_dir, "Li.cif"))
sif2 = StructInputFile.from_file(os.path.join(test_dir, "LiFePO4.cif"))
sif3 = StructInputFile.from_file(os.path.join(test_dir, "Li2O.cif"))
inp_set = InputSet(
{
"cif1": sif1,
"cif2": sif2,
"cif3": sif3,
},
kwarg1=1,
kwarg2="hello",
)

assert len(inp_set) == 3
assert inp_set.kwarg1 == 1
assert inp_set.kwarg2 == "hello"
with pytest.raises(AttributeError):
inp_set.kwarg3
expected = [("cif1", sif1), ("cif2", sif2), ("cif3", sif3)]
for (fname, contents), (exp_fname, exp_contents) in zip(inp_set, expected):
assert fname == exp_fname
assert contents is exp_contents

assert inp_set["cif1"] is sif1
with pytest.raises(KeyError):
inp_set["kwarg1"]

sif4 = StructInputFile.from_file(os.path.join(test_dir, "CuCl.cif"))
inp_set["cif4"] = sif4
assert inp_set.inputs["cif4"] is sif4
assert len(inp_set) == 4

del inp_set["cif2"]
del inp_set["cif4"]

assert len(inp_set) == 2
expected = [("cif1", sif1), ("cif3", sif3)]
for (fname, contents), (exp_fname, exp_contents) in zip(inp_set, expected):
assert fname == exp_fname
assert contents is exp_contents

def test_msonable(self):
sif1 = StructInputFile.from_file(os.path.join(test_dir, "Li.cif"))
sif2 = StructInputFile.from_file(os.path.join(test_dir, "Li2O.cif"))
inp_set = InputSet(
{
"cif1": sif1,
"cif2": sif2,
},
kwarg1=1,
kwarg2="hello",
)

inp_set_dict = inp_set.as_dict()
decoder = MontyDecoder()
temp_inp_set = decoder.process_decoded(inp_set_dict)
assert isinstance(temp_inp_set, InputSet)
assert temp_inp_set.kwarg1 == 1
assert temp_inp_set.kwarg2 == "hello"
assert temp_inp_set._kwargs == inp_set._kwargs
for (fname, contents), (fname2, contents2) in zip(temp_inp_set, inp_set):
assert fname == fname2
assert contents.structure == contents2.structure

def test_write(self):
sif1 = StructInputFile.from_file(os.path.join(test_dir, "Li.cif"))
sif2 = StructInputFile.from_file(os.path.join(test_dir, "Li2O.cif"))
inp_set = InputSet(
{
"cif1": sif1,
"cif2": sif2,
},
kwarg1=1,
kwarg2="hello",
)
with ScratchDir("."):
inp_set.write_input(directory="input_dir", make_dir=True, overwrite=True, zip_inputs=False)
assert os.path.exists(os.path.join("input_dir", "cif1"))
assert os.path.exists(os.path.join("input_dir", "cif2"))
assert len(os.listdir("input_dir")) == 2
with pytest.raises(FileExistsError):
inp_set.write_input(directory="input_dir", make_dir=True, overwrite=False, zip_inputs=False)
inp_set.write_input(directory="input_dir", make_dir=True, overwrite=True, zip_inputs=True)
assert len(os.listdir("input_dir")) == 1
assert os.path.exists(os.path.join("input_dir", f"{type(inp_set).__name__}.zip"))
with pytest.raises(FileNotFoundError):
inp_set.write_input(directory="input_dir2", make_dir=False, overwrite=True, zip_inputs=False)

inp_set = InputSet(
{"cif1": sif1, "file_from_str": "hello you", "file_from_strcast": FakeClass(a="Aha", b="Beh")}
)
with ScratchDir("."):
inp_set.write_input(directory="input_dir", make_dir=True, overwrite=True, zip_inputs=False)
assert os.path.exists(os.path.join("input_dir", "cif1"))
assert os.path.exists(os.path.join("input_dir", "file_from_str"))
assert os.path.exists(os.path.join("input_dir", "file_from_strcast"))
assert len(os.listdir("input_dir")) == 3
cp = CifParser(filename=os.path.join("input_dir", "cif1"))
assert cp.get_structures()[0] == sif1.structure
with open(os.path.join("input_dir", "file_from_str")) as f:
file_from_str = f.read()
assert file_from_str == "hello you"
with open(os.path.join("input_dir", "file_from_strcast")) as f:
file_from_strcast = f.read()
assert file_from_strcast == "Aha\nBeh"

0 comments on commit 8a78794

Please sign in to comment.