Skip to content

Commit

Permalink
Re-add the saving of config. Fix config loading.
Browse files Browse the repository at this point in the history
  • Loading branch information
Dan Moldovan committed Feb 17, 2020
1 parent 5fb9a9f commit 6cddad6
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 35 deletions.
15 changes: 10 additions & 5 deletions astronet/train.py
Expand Up @@ -29,6 +29,7 @@

from astronet import models
from astronet.ops import dataset_ops
from astronet.util import config_util

parser = argparse.ArgumentParser()

Expand Down Expand Up @@ -80,6 +81,14 @@


def train(model, config):
if FLAGS.model_dir:
dir_name = "{}/{}_{}_{}".format(
FLAGS.model_dir,
FLAGS.model,
FLAGS.config_name,
datetime.datetime.now().strftime('%Y%m%d_%H%M%S'))
config_util.log_and_save_config(config, dir_name)

ds = dataset_ops.build_dataset(
file_pattern=FLAGS.train_files,
input_config=config['inputs'],
Expand Down Expand Up @@ -144,12 +153,8 @@ def train(model, config):
validation_data=eval_ds)

if FLAGS.model_dir:
dir_name = "{}/{}_{}".format(
FLAGS.model_dir,
FLAGS.config_name,
datetime.datetime.now().strftime('%Y%m%d_%H%M%S'))
tf.saved_model.save(model, export_dir=dir_name)
print("Model saved to {}.".format(dir_name))
print("Model saved:\n {}\n".format(dir_name))

return history

Expand Down
36 changes: 6 additions & 30 deletions astronet/util/config_util.py
Expand Up @@ -23,47 +23,23 @@


from absl import logging
from astronet.util import configdict


import tensorflow as tf


def parse_json(json_string_or_file):
"""Parses values from a JSON string or JSON file.
This function is useful for command line flags containing configuration
overrides. Using this function, the flag can be passed either as a JSON string
(e.g. '{"learning_rate": 1.0}') or the path to a JSON configuration file.
def parse_json(json_file):
"""Parses values from a JSON file.
Args:
json_string_or_file: A JSON serialized string OR the path to a JSON file.
json_file: The path to a JSON file.
Returns:
A dictionary; the parsed JSON.
Raises:
ValueError: If the JSON could not be parsed.
"""
# First, attempt to parse the string as a JSON dict.
try:
json_dict = json.loads(json_string_or_file)
except ValueError as literal_json_parsing_error:
try:
# Otherwise, try to use it as a path to a JSON file.
with tf.io.gfile.Open(json_string_or_file) as f:
json_dict = json.load(f)
except ValueError as json_file_parsing_error:
raise ValueError("Unable to parse the content of the json file %s. "
"Parsing error: %s." % (json_string_or_file,
json_file_parsing_error.message))
except tf.io.gfile.FileError:
message = ("Unable to parse the input parameter neither as literal "
"JSON nor as the name of a file that exists.\n"
"JSON parsing error: %s\n\n Input parameter:\n%s." %
(literal_json_parsing_error.message, json_string_or_file))
raise ValueError(message)

return json_dict
with tf.io.gfile.GFile(json_file, 'r') as f:
return configdict.ConfigDict(json.loads(f.read()))


def log_and_save_config(config, output_dir):
Expand Down

0 comments on commit 6cddad6

Please sign in to comment.