Add support for arbitrary tasks
Co-Authored-By: James Widdicombe <j.y.widdicombe@gmail.com>
lgeiger and jamescook106 committed Aug 7, 2019
1 parent 08c0f30 commit d84cf96
Showing 7 changed files with 74 additions and 48 deletions.
README.md: 4 changes (2 additions & 2 deletions)

@@ -7,7 +7,7 @@ A small library for managing deep learning models, hyper parameters and datasets
## Getting Started

Zookeeper allows you to build command line interfaces for training deep learning models with very little boilerplate using [click](https://click.palletsprojects.com/) and [TensorFlow Datasets](https://www.tensorflow.org/datasets/). It helps you structure your machine learning projects in a framework-agnostic and effective way.
-Zookeeper is heavily inspired by [Tensor2Tensor](https://github.com/tensorflow/tensor2tensor) and [Fairseq](https://github.com/pytorch/fairseq/) but is designed to be used as a library, making it lightweight and very flexible. Currently zookeeper is limited to image classification tasks but we are working on making it useful for other tasks as well.
+Zookeeper is heavily inspired by [Tensor2Tensor](https://github.com/tensorflow/tensor2tensor) and [Fairseq](https://github.com/pytorch/fairseq/) but is designed to be used as a library, making it lightweight and very flexible.

### Installation

@@ -23,7 +23,7 @@ Zookeeper keeps track of data preprocessing, models and hyperparameters to allow
#### Datasets and Preprocessing

TensorFlow Datasets provides [many popular datasets](https://www.tensorflow.org/datasets/datasets) that can be downloaded automatically.
-In the following we will use [MNIST](http://yann.lecun.com/exdb/mnist) and define a `default` preprocessing for the images that scales the image to `[0, 1]`:
+In the following we will use [MNIST](http://yann.lecun.com/exdb/mnist) and define a `default` preprocessing for the images that scales the image to `[0, 1]` and uses one-hot encoding for the class labels:

```python
import tensorflow as tf
...
```
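The README's example code block is truncated in this view. As a minimal sketch, a class-based `default` preprocessing under the `Preprocessing` API introduced by this commit could look like the following; the body below is illustrative (the actual README contents beyond the import are not shown here), reusing only names that appear elsewhere in this diff:

```python
import tensorflow as tf
from zookeeper import registry, Preprocessing


@registry.register_preprocess("mnist")
class default(Preprocessing):
    # Model-building kwargs derived from the dataset's feature spec.
    @property
    def kwargs(self):
        return {
            "input_shape": self.features["image"].shape,
            "num_classes": self.features["label"].num_classes,
        }

    def inputs(self, data):
        # Scale uint8 images from [0, 255] to floats in [0, 1].
        return tf.cast(data["image"], tf.float32) / 255.0

    def outputs(self, data):
        # One-hot encode the integer class labels.
        return tf.one_hot(data["label"], self.features["label"].num_classes)
```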
zookeeper/__init__.py: 3 changes (2 additions & 1 deletion)

@@ -1,4 +1,5 @@
from zookeeper.hparam import HParams
+from zookeeper.preprocessing import Preprocessing
from zookeeper.cli import build_train, cli

-__all__ = ["build_train", "cli", "HParams"]
+__all__ = ["build_train", "cli", "HParams", "Preprocessing"]
zookeeper/cli_test.py: 35 changes (27 additions & 8 deletions)

@@ -1,23 +1,42 @@
-from zookeeper import registry, cli, build_train, HParams, data
+from zookeeper import registry, cli, build_train, HParams, data, Preprocessing
from click.testing import CliRunner
import click
from unittest import mock
from os import path
import tensorflow as tf
+import tensorflow_datasets as tfds


+class ImageClassification(Preprocessing):
+    @property
+    def kwargs(self):
+        return {
+            "input_shape": self.features["image"].shape,
+            "num_classes": self.features["label"].num_classes,
+        }
+
+    def inputs(self, data):
+        return tf.cast(data["image"], tf.float32)
+
+    def outputs(self, data):
+        return tf.one_hot(data["label"], self.features["label"].num_classes)


@registry.register_preprocess("mnist")
-def default(image_tensor):
-    return image_tensor
+class default(ImageClassification):
+    pass


@registry.register_preprocess("mnist")
-def raw(image_bytes):
-    return tf.image.decode_image(image_bytes)
+class raw(ImageClassification):
+    decoders = {"image": tfds.decode.SkipDecoding()}
+
+    def inputs(self, data):
+        return tf.cast(tf.image.decode_image(data["image"]), tf.float32)


@registry.register_model
-def foo(hparams, dataset):
+def foo(hparams, **kwargs):
    return "foo-model"


@@ -36,7 +55,7 @@ def train(build_model, dataset, hparams, output_dir, custom_opt):
    assert isinstance(output_dir, str)
    assert isinstance(custom_opt, str)

-    model = build_model(hparams, dataset)
+    model = build_model(hparams, **dataset.preprocessing.kwargs)
    assert model == "foo-model"
    assert dataset.dataset_name == "mnist"
    assert dataset.train_examples == 60000
@@ -53,7 +72,7 @@ def train_val(build_model, dataset, hparams, output_dir):
    assert isinstance(dataset, data.Dataset)
    assert isinstance(output_dir, str)

-    model = build_model(hparams, dataset)
+    model = build_model(hparams, **dataset.preprocessing.kwargs)
    assert model == "foo-model"
    assert dataset.dataset_name == "mnist"
    assert dataset.train_examples == 54000
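With this change a registered model no longer receives the whole dataset object; it gets the preprocessing's `kwargs` property expanded as keyword arguments. A minimal sketch of a realistic model function written against the new calling convention (`simple_cnn` and its layers are illustrative, not part of this commit):

```python
import tensorflow as tf
from zookeeper import registry


@registry.register_model
def simple_cnn(hparams, input_shape, num_classes, **kwargs):
    # `input_shape` and `num_classes` arrive via **dataset.preprocessing.kwargs.
    return tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation="relu", input_shape=input_shape),
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(num_classes, activation="softmax"),
    ])
```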
zookeeper/data.py: 37 changes (14 additions & 23 deletions)

@@ -10,33 +10,26 @@ class Dataset:
    def __init__(
        self,
        dataset_name,
-        preprocess_fn,
+        preprocess_cls,
        use_val_split=False,
        cache_dir=None,
        data_dir=None,
        version=None,
    ):
        self.dataset_name = dataset_name
-        self.preprocess_fn = preprocess_fn
        self.data_dir = data_dir
        self.cache_dir = cache_dir
        self.version = version

-        dataset_builder = tfds.builder(self.dataset_name_str)
-        self.info = dataset_builder.info
+        self.info = tfds.builder(self.dataset_name_str).info
        splits = self.info.splits
-        features = self.info.features
+        self.preprocessing = preprocess_cls(features=self.info.features)

        if tfds.Split.TRAIN not in splits:
            raise ValueError("To train we require a train split in the dataset.")
        if tfds.Split.TEST not in splits and tfds.Split.VALIDATION not in splits:
            raise ValueError("We require a test or validation split in the dataset.")
-        if not {"image", "label"} <= set(self.info.supervised_keys or []):
-            raise NotImplementedError("We currently only support image classification")

-        self.num_classes = features["label"].num_classes
-        self.input_shape = getattr(
-            preprocess_fn, "input_shape", features["image"].shape
-        )
        self.train_split = tfds.Split.TRAIN
        self.train_examples = splits[self.train_split].num_examples
        if tfds.Split.TEST in splits:

@@ -67,23 +60,21 @@ def load_split(self, split, shuffle=True):
            name=self.dataset_name_str,
            split=split,
            data_dir=self.data_dir,
-            decoders={"image": tfds.decode.SkipDecoding()},
+            decoders=self.preprocessing.decoders,
            as_dataset_kwargs={"shuffle_files": shuffle},
        )

-    def map_fn(self, data, training=False):
-        args = inspect.getfullargspec(self.preprocess_fn).args
-        image_or_bytes = (
-            data["image"]
-            if "image_bytes" == args[0]
-            else self.info.features["image"].decode_example(data["image"])
-        )
-        if "training" in args:
-            image = self.preprocess_fn(image_or_bytes, training=training)
+    def _call_prepro(self, preprocess_fn, data, training=False):
+        if "training" in inspect.getfullargspec(preprocess_fn).args:
+            return preprocess_fn(data, training=training)
        else:
-            image = self.preprocess_fn(image_or_bytes)
+            return preprocess_fn(data)

-        return image, tf.one_hot(data["label"], self.num_classes)
+    def map_fn(self, data, training=False):
+        return (
+            self._call_prepro(self.preprocessing.inputs, data, training=training),
+            self._call_prepro(self.preprocessing.outputs, data, training=training),
+        )

    def get_cache_path(self, split_name):
        if self.cache_dir is None:
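`_call_prepro` inspects the method's signature so that preprocessing code can opt in to a `training` flag. A sketch of what that enables (the `Augmented` class below is illustrative, not part of this commit):

```python
import tensorflow as tf
from zookeeper import Preprocessing


class Augmented(Preprocessing):
    # Declaring a `training` parameter means `_call_prepro` passes the flag
    # through, so the random flip only runs on the training split.
    def inputs(self, data, training=False):
        image = tf.cast(data["image"], tf.float32)
        if training:
            image = tf.image.random_flip_left_right(image)
        return image

    def outputs(self, data):
        return tf.one_hot(data["label"], self.features["label"].num_classes)
```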
zookeeper/data_vis.py: 3 changes (1 addition & 2 deletions)

@@ -29,9 +29,8 @@ def plot_examples(dataset):

def plot_all_examples(set):
    dataset = set.load_split(set.train_split, shuffle=False)
-    decoder = set.info.features["image"].decode_example

-    raw = plot_examples(dataset.map(lambda feat: decoder(feat["image"])))
+    raw = plot_examples(dataset.map(lambda feat: feat["image"]))
    train = plot_examples(dataset.map(lambda feat: set.map_fn(feat, training=True)[0]))
    eval = plot_examples(dataset.map(lambda feat: set.map_fn(feat)[0]))
    return raw, train, eval
zookeeper/preprocessing.py: 17 changes (17 additions & 0 deletions)

@@ -0,0 +1,17 @@
+import abc
+
+
+class Preprocessing(abc.ABC):
+    decoders = None
+    kwargs = {}
+
+    def __init__(self, features=None):
+        self.features = features
+
+    @abc.abstractmethod
+    def inputs(self, data, training):
+        raise NotImplementedError("Must be implemented in subclasses.")
+
+    @abc.abstractmethod
+    def outputs(self, data, training):
+        raise NotImplementedError("Must be implemented in subclasses.")
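This base class is what makes the commit title's "arbitrary tasks" possible: nothing in it assumes images or labels. As a minimal sketch of a non-image task, a text-classification preprocessing might look like the following (the class and the "text"/"label" feature names are illustrative, assuming a TFDS text dataset):

```python
import tensorflow as tf
from zookeeper import Preprocessing


class TextClassification(Preprocessing):
    @property
    def kwargs(self):
        return {"num_classes": self.features["label"].num_classes}

    def inputs(self, data):
        # Pass the raw text through; tokenization could happen here instead.
        return data["text"]

    def outputs(self, data):
        return tf.one_hot(data["label"], self.features["label"].num_classes)
```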
zookeeper/registry.py: 23 changes (11 additions & 12 deletions)

@@ -1,5 +1,6 @@
import tensorflow_datasets as tfds
from zookeeper.data import Dataset
+from zookeeper.preprocessing import Preprocessing

MODEL_REGISTRY = {}
HPARAMS_REGISTRY = {}
@@ -44,24 +45,22 @@ def __init__(self, model_name, name):
        ValueError.__init__(self, err)


-def register_preprocess(dataset_name, input_shape=None):
-    def register_preprocess_fn(fn):
-        if not callable(fn):
-            raise ValueError("Preprocess function must be callable")
-        name = fn.__name__
+def register_preprocess(dataset_name):
+    def register_preprocess_cls(cls):
+        if not issubclass(cls, Preprocessing):
+            raise ValueError("Preprocess must be a subclass of zookeeper.Preprocessing")
+        name = cls.__name__
        if dataset_name not in DATA_REGISTRY:
            raise DatasetNotFoundError(dataset_name)
-        data_preprocess_fns = DATA_REGISTRY[dataset_name]
-        if name in data_preprocess_fns:
+        data_preprocess_clss = DATA_REGISTRY[dataset_name]
+        if name in data_preprocess_clss:
            raise ValueError(
                f"Cannot register duplicate preprocessing ({name}) for dataset ({dataset_name})"
            )
-        if input_shape:
-            setattr(fn, "input_shape", input_shape)
-        data_preprocess_fns[name] = fn
-        return fn
+        data_preprocess_clss[name] = cls
+        return cls

-    return register_preprocess_fn
+    return register_preprocess_cls


def register_model(model):
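The decorator now validates the registered object's type instead of checking callability. A sketch of the resulting behaviour (assuming "mnist" is already a registered dataset, as in cli_test.py; the `normalized` class is illustrative):

```python
from zookeeper import registry, Preprocessing


@registry.register_preprocess("mnist")
class normalized(Preprocessing):  # accepted: a Preprocessing subclass
    def inputs(self, data):
        return data["image"]

    def outputs(self, data):
        return data["label"]


# A class that does not subclass Preprocessing would be rejected:
# @registry.register_preprocess("mnist")
# class bad:
#     ...
# -> ValueError: Preprocess must be a subclass of zookeeper.Preprocessing
```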
