Skip to content

Commit

Permalink
Simplify passing trainable
Browse files Browse the repository at this point in the history
  • Loading branch information
lgeiger committed Aug 7, 2019
1 parent d84cf96 commit cd42a25
Showing 1 changed file with 9 additions and 10 deletions.
19 changes: 9 additions & 10 deletions zookeeper/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@
import tensorflow_datasets as tfds


def pass_training_kwarg(function, training=False):
if "training" in inspect.getfullargspec(function).args:
return functools.partial(function, training=training)
return function


class Dataset:
def __init__(
self,
Expand Down Expand Up @@ -64,17 +70,10 @@ def load_split(self, split, shuffle=True):
as_dataset_kwargs={"shuffle_files": shuffle},
)

def _call_prepro(self, preprocess_fn, data, training=False):
if "training" in inspect.getfullargspec(preprocess_fn).args:
return preprocess_fn(data, training=training)
else:
return preprocess_fn(data)

def map_fn(self, data, training=False):
return (
self._call_prepro(self.preprocessing.inputs, data, training=training),
self._call_prepro(self.preprocessing.outputs, data, training=training),
)
input_fn = pass_training_kwarg(self.preprocessing.inputs, training=training)
output_fn = pass_training_kwarg(self.preprocessing.outputs, training=training)
return input_fn(data), output_fn(data)

def get_cache_path(self, split_name):
if self.cache_dir is None:
Expand Down

0 comments on commit cd42a25

Please sign in to comment.