Skip to content

Commit

Permalink
Fix numpy arrays in JsonParser.generate, add format_content utili…
Browse files Browse the repository at this point in the history
…ty (#739)

* fix json serialization of numpy arr

* add utility function to generate pretty config from a cfg node

* rename to `format_content`
  • Loading branch information
Helveg committed Jul 24, 2023
1 parent e61bfaa commit bde9795
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
12 changes: 12 additions & 0 deletions bsb/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,9 @@ def copy_template(self, template, output="network_configuration.json", path=None
copy_file(files[0], output)

def from_file(self, file):
"""
Create a configuration object from a path or file-like object.
"""
if not hasattr(file, "read"):
with open(file, "r") as f:
return self.from_file(f)
Expand All @@ -149,10 +152,19 @@ def from_file(self, file):
return self.from_content(file.read(), path)

def from_content(self, content, path=None):
"""
Create a configuration object from a content string
"""
ext = path.split(".")[-1] if path is not None else None
parser, tree, meta = _try_parsers(content, self._parser_classes, ext, path=path)
return _from_parsed(self, parser, tree, meta, path)

def format_content(self, parser_name, config):
"""
Convert a configuration object to a string using the given parser.
"""
return self.get_parser(parser_name).generate(config.__tree__(), pretty=True)

__all__ = [*(vars().keys() - {"__init__", "__qualname__", "__module__"})]


Expand Down
12 changes: 10 additions & 2 deletions bsb/config/parsers/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""

import json
import numpy as np
import os
from ...exceptions import JsonImportError, ConfigurationWarning, JsonReferenceError
from ...reporting import warn
Expand Down Expand Up @@ -142,6 +143,13 @@ def resolve(self, parser, target):
self.node[key] = target[key]


def _to_json(value):
if isinstance(value, np.ndarray):
return value.tolist()
else:
raise TypeError()


class JsonParser(Parser):
"""
Parser plugin class to parse JSON configuration files.
Expand Down Expand Up @@ -173,9 +181,9 @@ def parse(self, content, path=None):

def generate(self, tree, pretty=False):
if pretty:
return json.dumps(tree, indent=4)
return json.dumps(tree, indent=4, default=_to_json)
else:
return json.dumps(tree)
return json.dumps(tree, default=_to_json)

def _traverse(self, node, iter):
# Iterates over all values in `iter` and checks for import keys, recursion or refs
Expand Down

0 comments on commit bde9795

Please sign in to comment.