Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add file arguments #26

Merged
merged 1 commit into from
Mar 3, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
35 changes: 21 additions & 14 deletions f90nml/namelist.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,22 +223,30 @@ def false_repr(self, value):
def write(self, nml_path, force=False):
"""Output dict to a Fortran 90 namelist file."""

if not force and os.path.isfile(nml_path):
nml_is_file = hasattr(nml_path, 'read')
if not force and not nml_is_file and os.path.isfile(nml_path):
raise IOError('File {0} already exists.'.format(nml_path))

with open(nml_path, 'w') as nml_file:
for grp_name, grp_vars in self.items():
# Check for repeated namelist records (saved as lists)
if isinstance(grp_vars, list):
for g_vars in grp_vars:
self.write_nmlgrp(grp_name, g_vars, nml_file)
else:
self.write_nmlgrp(grp_name, grp_vars, nml_file)
# collect individual (name, namelist) pairs, including repeated names
groups = []
for (name,vals) in self.items():
if isinstance(vals, list): # repeated namelist
groups.extend((name, v) for v in vals)
else:
groups.append((name, vals))

nml_file = nml_path if nml_is_file else open(nml_path, 'w')
try:
if len(groups) > 0:
first_name, first_vals = groups[0]
self.write_nmlgrp(first_name, first_vals, nml_file)

if self.items():
with open(nml_path, 'rb+') as nml_file:
nml_file.seek(-1, os.SEEK_END)
nml_file.truncate()
for (name, vals) in groups[1:]:
print(file=nml_file) # double-space between groups
self.write_nmlgrp(name, vals, nml_file)
finally:
if not nml_is_file:
nml_file.close()

def write_nmlgrp(self, grp_name, grp_vars, nml_file):
"""Write namelist group to target file."""
Expand All @@ -255,7 +263,6 @@ def write_nmlgrp(self, grp_name, grp_vars, nml_file):
print(nml_line, file=nml_file)

print('/', file=nml_file)
print(file=nml_file)

def var_strings(self, v_name, v_values, v_idx=None):
"""Convert namelist variable to list of fixed-width strings."""
Expand Down
23 changes: 16 additions & 7 deletions f90nml/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,30 +75,39 @@ def read(self, nml_fname, nml_patch_in=None, patch_fname=None):
>>> parser = Parser()
>>> data_nml = parser.read('data.nml')"""

# For switching based on files versus paths
nml_is_path = not hasattr(nml_fname, 'read')
patch_is_path = not hasattr(patch_fname, 'read')

# Convert patch data to a Namelist object
if nml_patch_in:
if not isinstance(nml_patch_in, dict):
raise ValueError('Input patch must be a dict or a Namelist.')

nml_patch = copy.deepcopy(Namelist(nml_patch_in))

if not patch_fname:
if not patch_fname and nml_is_path:
patch_fname = nml_fname + '~'
elif not patch_fname:
raise ValueError('f90nml: error: No output file for patch.')
elif nml_fname == patch_fname:
raise ValueError('f90nml: error: Patch filepath cannot be the '
'same as the original filepath.')
self.pfile = open(patch_fname, 'w')
self.pfile = open(patch_fname, 'w') if patch_is_path else patch_fname
else:
nml_patch = Namelist()

try:
nml_file = open(nml_fname, 'r')
return self.readstream(nml_file, nml_patch)
nml_file = open(nml_fname, 'r') if nml_is_path else nml_fname
try:
return self.readstream(nml_file, nml_patch)

# Close the files we opened on any exceptions within readstream
finally:
if nml_is_path:
nml_file.close()
finally:
# Close the unfinished files on any exceptions within readstream
nml_file.close()
if self.pfile:
if self.pfile and patch_is_path:
self.pfile.close()

def readstream(self, nml_file, nml_patch):
Expand Down
33 changes: 32 additions & 1 deletion test/test_f90nml.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,13 +237,27 @@ def assert_file_equal(self, source_fname, target_fname):
self.assertEqual(source_str, target_str)

def assert_write(self, nml, target_fname):
self.assert_write_path(nml, target_fname)
self.assert_write_file(nml, target_fname)

def assert_write_path(self, nml, target_fname):
tmp_fname = 'tmp.nml'
f90nml.write(nml, tmp_fname)
try:
self.assert_file_equal(tmp_fname, target_fname)
finally:
os.remove(tmp_fname)

def assert_write_file(self, nml, target_fname):
tmp_fname = 'tmp.nml'
with open(tmp_fname, 'w') as tmp_file:
f90nml.write(nml, tmp_file)
self.assertFalse(tmp_file.closed)
try:
self.assert_file_equal(tmp_fname, target_fname)
finally:
os.remove(tmp_fname)

# Tests
def test_empty(self):
test_nml = f90nml.read('empty.nml')
Expand Down Expand Up @@ -352,7 +366,7 @@ def test_pop_key(self):
test_nml.pop('empty_nml')
self.assertEqual(test_nml, f90nml.namelist.Namelist())

def test_patch(self):
def test_patch_paths(self):
patch_nml = f90nml.read('types_patch.nml')
f90nml.patch('types.nml', patch_nml, 'tmp.nml')
test_nml = f90nml.read('tmp.nml')
Expand All @@ -361,6 +375,19 @@ def test_patch(self):
finally:
os.remove('tmp.nml')

def test_patch_files(self):
patch_nml = f90nml.read('types_patch.nml')
with open('types.nml') as f_in:
with open('tmp.nml', 'w') as f_out:
f90nml.patch(f_in, patch_nml, f_out)
self.assertFalse(f_in.closed)
self.assertFalse(f_out.closed)
try:
test_nml = f90nml.read('tmp.nml')
self.assertEqual(test_nml, patch_nml)
finally:
os.remove('tmp.nml')

def test_patch_case(self):
patch_nml = f90nml.read('types_patch.nml')
f90nml.patch('types_uppercase.nml', patch_nml, 'tmp.nml')
Expand All @@ -383,6 +410,10 @@ def test_default_patch(self):
finally:
os.remove('types.nml~')

# The above behavior is only for paths, not files
with open('types.nml') as nml_file:
self.assertRaises(ValueError, f90nml.patch, nml_file, patch_nml)

def test_no_selfpatch(self):
patch_nml = f90nml.read('types_patch.nml')
self.assertRaises(ValueError, f90nml.patch,
Expand Down