In [4]:
# import utils
import data
# import ingestor
# import extractor2

In [6]:
def generic_parser(serialized_example, feature_list, label_list):
  """Decodes one serialized HOL tf.Example and selects features and labels.

  Args:
    serialized_example: A serialized tf.Example describing a parameterized
      tactic application.
    feature_list: Names of parsed fields to return as features.
    label_list: Names of parsed fields to return as labels.

  Returns:
    A `(features, labels)` tuple of dicts keyed by the names in
    `feature_list` and `label_list` respectively. The variable-length string
    fields ('goal_asl', 'thms', 'thms_hard_negatives') are converted from
    sparse to dense tensors, padded with ''.

  Raises:
    KeyError: If a requested name is not one of the parsed fields.
  """
  # Fields declared as VarLenFeature come back sparse; they are densified below.
  sparse_fields = ('goal_asl', 'thms', 'thms_hard_negatives')

  spec = {
      # Subgoal: consequent term, as a string.
      'goal': tf.FixedLenFeature((), tf.string, default_value=''),
      # Subgoal: list of hypotheses.
      'goal_asl': tf.VarLenFeature(dtype=tf.string),
      # Name of the tactic applied to the subgoal.
      'tactic': tf.FixedLenFeature((), tf.string, default_value=''),
      # Integer id of the tactic; -1 when the field is missing.
      'tac_id': tf.FixedLenFeature((), tf.int64, default_value=-1),
      # Tactic arguments of type thm.
      'thms': tf.VarLenFeature(dtype=tf.string),
      # Hard-negative theorem parameter arguments.
      'thms_hard_negatives': tf.VarLenFeature(dtype=tf.string),
  }
  parsed = tf.parse_single_example(serialized_example, features=spec)

  # Densify in the fixed order of `sparse_fields`, skipping absent keys.
  parsed.update({
      name: tf.sparse_tensor_to_dense(parsed[name], default_value='')
      for name in sparse_fields
      if name in parsed
  })

  def select(names):
    return {name: parsed[name] for name in names}

  return select(feature_list), select(label_list)


In [10]:
def tristan_parser(serialized_example, source, params):
  """Parses one serialized example into (features, labels) for training.

  Keeps the 'goal', 'thms' and 'thms_hard_negatives' features and the
  'tac_id' label via `generic_parser`. It then post-processes the theorem
  fields with this module's helper functions.

  Args:
    serialized_example: A serialized tf.Example (see `generic_parser`).
    source: Unused. Presumably present because dataset elements are
      `(serialized_example, source)` pairs — confirm against `dataset_fn`.
    params: Hyperparameters forwarded to
      `_shuffle_and_truncate_hard_negatives`.

  Returns:
    features, labels: the dicts produced by `generic_parser`, with
    features['thms'] and features['thms_hard_negatives'] replaced by the
    helpers' outputs.
  """
  del source  # Unused.

  features, labels = generic_parser(
      serialized_example,
      feature_list=['goal', 'thms', 'thms_hard_negatives'],
      label_list=['tac_id'])

  # thms: reduce to a single theorem.
  # NOTE(review): if an example can have zero 'thms', confirm that
  # _choose_one_theorem_at_random handles the empty case.
  features['thms'] = _choose_one_theorem_at_random(features['thms'])

  # thms_hard_negatives: shuffle, truncate and then pad with '<NULL>'.
  features['thms_hard_negatives'] = _shuffle_and_truncate_hard_negatives(
      features['thms_hard_negatives'], params)

  return features, labels

In [15]:
def get_input_fn(dataset_fn,
                 mode,
                 params,
                 shuffle=None,
                 shuffle_queue=None,
                 repeat=None,
                 parser=None,
                 filt=None):
  """Returns an Estimator-style `input_fn` that builds a batched pipeline.

  Pipeline order: dataset_fn(params) -> [cache] -> [repeat] -> [shuffle]
  -> map(parser) -> [filter] -> batch -> one-shot iterator.

  Args:
    dataset_fn: Callable taking `params` and returning a dataset.
    mode: Compared against the module-level TRAIN / EVAL constants.
    params: Dict-like hyperparameters. Reads 'batch_size', optionally 'cache'
      (default False) and, when `shuffle_queue` is None, 'shuffle_queue'.
      It is not mutated.
    shuffle: Whether to shuffle; defaults to `mode == TRAIN`.
    shuffle_queue: Shuffle buffer size; defaults to params['shuffle_queue'].
    repeat: None -> repeat indefinitely iff `mode == TRAIN`; True -> repeat
      indefinitely; a falsy value (False / 0) -> do not repeat; any other
      value is passed to `repeat()` as the count.
    parser: Mapped over the dataset with `params=params` bound; defaults to
      `tristan_parser`.
    filt: Optional predicate applied with `filter()` after parsing.

  Returns:
    A zero-argument `input_fn` that returns `get_next()` of a one-shot
    iterator over the batched dataset.
  """
  if shuffle_queue is None:
    # Item access, consistent with params['batch_size'] below.
    shuffle_queue = params['shuffle_queue']
  if shuffle is None:
    shuffle = mode == TRAIN

  # Normalize `repeat` so both names are always bound. The previous code left
  # `do_repeat` unassigned for truthy values and passed True through as a count.
  if repeat is None:
    repeat_count, do_repeat = None, mode == TRAIN
  elif repeat is True:
    repeat_count, do_repeat = None, True
  elif not repeat:
    repeat_count, do_repeat = None, False
  else:
    repeat_count, do_repeat = repeat, True

  if parser is None:
    tf.logging.info('PASSED IN parser is None')
    parser = tristan_parser  # pairwise_thm_parser

  def input_fn():
    """Input Function for estimator."""
    ds = dataset_fn(params)
    # .get rather than .setdefault: do not mutate the caller's params.
    if params.get('cache', False):
      ds = ds.cache()
    if repeat_count is not None:
      ds = ds.repeat(repeat_count)
    elif do_repeat:
      ds = ds.repeat()  # No count: repeat indefinitely.
    if shuffle:
      ds = ds.shuffle(shuffle_queue)

    ds = ds.map(functools.partial(parser, params=params))

    if filt is not None:
      ds = ds.filter(filt)

    # The final partial batch is dropped only in EVAL mode.
    ds = ds.batch(params['batch_size'], drop_remainder=(mode == EVAL))

    return ds.make_one_shot_iterator().get_next()

  return input_fn