Skip to content

Commit

Permalink
Merge pull request #26 from ExpHP/file-arguments
Browse files Browse the repository at this point in the history
Add file arguments
  • Loading branch information
marshallward committed Mar 3, 2016
2 parents dd7df09 + 47a2530 commit 8054312
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 22 deletions.
35 changes: 21 additions & 14 deletions f90nml/namelist.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,22 +223,30 @@ def false_repr(self, value):
def write(self, nml_path, force=False):
"""Output dict to a Fortran 90 namelist file."""

if not force and os.path.isfile(nml_path):
nml_is_file = hasattr(nml_path, 'read')
if not force and not nml_is_file and os.path.isfile(nml_path):
raise IOError('File {0} already exists.'.format(nml_path))

with open(nml_path, 'w') as nml_file:
for grp_name, grp_vars in self.items():
# Check for repeated namelist records (saved as lists)
if isinstance(grp_vars, list):
for g_vars in grp_vars:
self.write_nmlgrp(grp_name, g_vars, nml_file)
else:
self.write_nmlgrp(grp_name, grp_vars, nml_file)
# collect individual (name, namelist) pairs, including repeated names
groups = []
for (name,vals) in self.items():
if isinstance(vals, list): # repeated namelist
groups.extend((name, v) for v in vals)
else:
groups.append((name, vals))

nml_file = nml_path if nml_is_file else open(nml_path, 'w')
try:
if len(groups) > 0:
first_name, first_vals = groups[0]
self.write_nmlgrp(first_name, first_vals, nml_file)

if self.items():
with open(nml_path, 'rb+') as nml_file:
nml_file.seek(-1, os.SEEK_END)
nml_file.truncate()
for (name, vals) in groups[1:]:
print(file=nml_file) # double-space between groups
self.write_nmlgrp(name, vals, nml_file)
finally:
if not nml_is_file:
nml_file.close()

def write_nmlgrp(self, grp_name, grp_vars, nml_file):
"""Write namelist group to target file."""
Expand All @@ -255,7 +263,6 @@ def write_nmlgrp(self, grp_name, grp_vars, nml_file):
print(nml_line, file=nml_file)

print('/', file=nml_file)
print(file=nml_file)

def var_strings(self, v_name, v_values, v_idx=None):
"""Convert namelist variable to list of fixed-width strings."""
Expand Down
23 changes: 16 additions & 7 deletions f90nml/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,30 +75,39 @@ def read(self, nml_fname, nml_patch_in=None, patch_fname=None):
>>> parser = Parser()
>>> data_nml = parser.read('data.nml')"""

# For switching based on files versus paths
nml_is_path = not hasattr(nml_fname, 'read')
patch_is_path = not hasattr(patch_fname, 'read')

# Convert patch data to a Namelist object
if nml_patch_in:
if not isinstance(nml_patch_in, dict):
raise ValueError('Input patch must be a dict or a Namelist.')

nml_patch = copy.deepcopy(Namelist(nml_patch_in))

if not patch_fname:
if not patch_fname and nml_is_path:
patch_fname = nml_fname + '~'
elif not patch_fname:
raise ValueError('f90nml: error: No output file for patch.')
elif nml_fname == patch_fname:
raise ValueError('f90nml: error: Patch filepath cannot be the '
'same as the original filepath.')
self.pfile = open(patch_fname, 'w')
self.pfile = open(patch_fname, 'w') if patch_is_path else patch_fname
else:
nml_patch = Namelist()

try:
nml_file = open(nml_fname, 'r')
return self.readstream(nml_file, nml_patch)
nml_file = open(nml_fname, 'r') if nml_is_path else nml_fname
try:
return self.readstream(nml_file, nml_patch)

# Close the files we opened on any exceptions within readstream
finally:
if nml_is_path:
nml_file.close()
finally:
# Close the unfinished files on any exceptions within readstream
nml_file.close()
if self.pfile:
if self.pfile and patch_is_path:
self.pfile.close()

def readstream(self, nml_file, nml_patch):
Expand Down
33 changes: 32 additions & 1 deletion test/test_f90nml.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,13 +237,27 @@ def assert_file_equal(self, source_fname, target_fname):
self.assertEqual(source_str, target_str)

def assert_write(self, nml, target_fname):
self.assert_write_path(nml, target_fname)
self.assert_write_file(nml, target_fname)

def assert_write_path(self, nml, target_fname):
tmp_fname = 'tmp.nml'
f90nml.write(nml, tmp_fname)
try:
self.assert_file_equal(tmp_fname, target_fname)
finally:
os.remove(tmp_fname)

def assert_write_file(self, nml, target_fname):
tmp_fname = 'tmp.nml'
with open(tmp_fname, 'w') as tmp_file:
f90nml.write(nml, tmp_file)
self.assertFalse(tmp_file.closed)
try:
self.assert_file_equal(tmp_fname, target_fname)
finally:
os.remove(tmp_fname)

# Tests
def test_empty(self):
test_nml = f90nml.read('empty.nml')
Expand Down Expand Up @@ -352,7 +366,7 @@ def test_pop_key(self):
test_nml.pop('empty_nml')
self.assertEqual(test_nml, f90nml.namelist.Namelist())

def test_patch(self):
def test_patch_paths(self):
patch_nml = f90nml.read('types_patch.nml')
f90nml.patch('types.nml', patch_nml, 'tmp.nml')
test_nml = f90nml.read('tmp.nml')
Expand All @@ -361,6 +375,19 @@ def test_patch(self):
finally:
os.remove('tmp.nml')

def test_patch_files(self):
patch_nml = f90nml.read('types_patch.nml')
with open('types.nml') as f_in:
with open('tmp.nml', 'w') as f_out:
f90nml.patch(f_in, patch_nml, f_out)
self.assertFalse(f_in.closed)
self.assertFalse(f_out.closed)
try:
test_nml = f90nml.read('tmp.nml')
self.assertEqual(test_nml, patch_nml)
finally:
os.remove('tmp.nml')

def test_patch_case(self):
patch_nml = f90nml.read('types_patch.nml')
f90nml.patch('types_uppercase.nml', patch_nml, 'tmp.nml')
Expand All @@ -383,6 +410,10 @@ def test_default_patch(self):
finally:
os.remove('types.nml~')

# The above behavior is only for paths, not files
with open('types.nml') as nml_file:
self.assertRaises(ValueError, f90nml.patch, nml_file, patch_nml)

def test_no_selfpatch(self):
patch_nml = f90nml.read('types_patch.nml')
self.assertRaises(ValueError, f90nml.patch,
Expand Down

0 comments on commit 8054312

Please sign in to comment.