Merge pull request #40 from larq/make-it-general
Make it general
MariaHeuss committed Aug 8, 2019
2 parents 08c0f30 + f2ec5b9 commit f660123
Showing 9 changed files with 166 additions and 71 deletions.
36 changes: 24 additions & 12 deletions README.md
@@ -7,7 +7,7 @@ A small library for managing deep learning models, hyper parameters and datasets
## Getting Started

Zookeeper allows you to build command line interfaces for training deep learning models with very little boilerplate using [click](https://click.palletsprojects.com/) and [TensorFlow Datasets](https://www.tensorflow.org/datasets/). It helps you structure your machine learning projects in a framework-agnostic and effective way.
-Zookeeper is heavily inspired by [Tensor2Tensor](https://github.com/tensorflow/tensor2tensor) and [Fairseq](https://github.com/pytorch/fairseq/) but is designed to be used as a library making it lightweight and very flexible. Currently zookeeper is limited to image classification tasks but we are working on making it useful for other tasks as well.
+Zookeeper is heavily inspired by [Tensor2Tensor](https://github.com/tensorflow/tensor2tensor) and [Fairseq](https://github.com/pytorch/fairseq/) but is designed to be used as a library making it lightweight and very flexible.

### Installation

@@ -23,16 +23,31 @@ Zookeeper keeps track of data preprocessing, models and hyperparameters to allow
#### Datasets and Preprocessing

TensorFlow Datasets provides [many popular datasets](https://www.tensorflow.org/datasets/datasets) that can be downloaded automatically.
-In the following we will use [MNIST](http://yann.lecun.com/exdb/mnist) and define a `default` preprocessing for the images that scales the image to `[0, 1]`:
+In the following we will use [MNIST](http://yann.lecun.com/exdb/mnist) and define a `default` preprocessing for the images that scales the image to `[0, 1]` and uses one-hot encoding for the class labels:

```python
import tensorflow as tf
-from zookeeper import cli, build_train, HParams, registry
+from zookeeper import cli, build_train, HParams, registry, Preprocessing
+
+
+class ImageClassification(Preprocessing):
+    @property
+    def kwargs(self):
+        return {
+            "input_shape": self.features["image"].shape,
+            "num_classes": self.features["label"].num_classes,
+        }
+
+    def inputs(self, data):
+        return tf.cast(data["image"], tf.float32)
+
+    def outputs(self, data):
+        return tf.one_hot(data["label"], self.features["label"].num_classes)


@registry.register_preprocess("mnist")
-def default(image_tensor, training=False):
-    return tf.cast(image_tensor, dtype=tf.float32) / 255
+class default(ImageClassification):
+    def inputs(self, data):
+        return super().inputs(data) / 255
```
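For reference, a quick sketch of what this preprocessing hands to the model builder, assuming the `default` class above is in scope and the standard TFDS MNIST metadata (28×28×1 images, 10 classes):

```python
import tensorflow_datasets as tfds

# Build the preprocessing from TFDS feature metadata, the way zookeeper's
# Dataset does internally after this change.
features = tfds.builder("mnist").info.features
preprocessing = default(features=features)

print(preprocessing.kwargs)
# Expected: {'input_shape': (28, 28, 1), 'num_classes': 10}
```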

#### Models
@@ -41,22 +41,19 @@ Next we will register a model called `cnn`. We will use the [Keras API](https://

```python
@registry.register_model
-def cnn(hp, dataset):
+def cnn(hp, input_shape, num_classes):
    return tf.keras.models.Sequential(
        [
            tf.keras.layers.Conv2D(
-                hp.filters[0],
-                (3, 3),
-                activation=hp.activation,
-                input_shape=dataset.input_shape,
+                hp.filters[0], (3, 3), activation=hp.activation, input_shape=input_shape
            ),
            tf.keras.layers.MaxPooling2D((2, 2)),
            tf.keras.layers.Conv2D(hp.filters[1], (3, 3), activation=hp.activation),
            tf.keras.layers.MaxPooling2D((2, 2)),
            tf.keras.layers.Conv2D(hp.filters[2], (3, 3), activation=hp.activation),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(hp.filters[3], activation=hp.activation),
-            tf.keras.layers.Dense(dataset.num_classes, activation="softmax"),
+            tf.keras.layers.Dense(num_classes, activation="softmax"),
        ]
    )
```
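To see how the two registered pieces meet, here is a hedged sketch of building this model by hand. The `basic` hyperparameter class below is hypothetical (the project's real hyperparameter sets live in the collapsed region of this diff), and the sketch assumes `@registry.register_model` leaves the decorated function callable:

```python
from zookeeper import HParams

# Hypothetical hyperparameter set, for illustration only.
class basic(HParams):
    activation = "relu"
    filters = [32, 64, 64, 64]  # `cnn` reads filters[0] through filters[3]

# Shape information now arrives as plain keyword arguments instead of a
# dataset object; the values here are hard-coded to MNIST-like inputs.
model = cnn(basic(), input_shape=(28, 28, 1), num_classes=10)
model.summary()
```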
@@ -87,7 +99,7 @@ To train the models registered above we will need to write a custom training loop
@build_train()
def train(build_model, dataset, hparams, output_dir):
    """Start model training."""
-    model = build_model(hparams, dataset)
+    model = build_model(hparams, **dataset.preprocessing.kwargs)
    model.compile(
        optimizer=hparams.optimizer,
        loss="categorical_crossentropy",
33 changes: 23 additions & 10 deletions examples/train.py
@@ -1,29 +1,42 @@
-from zookeeper import cli, build_train, registry, HParams
import tensorflow as tf
+from zookeeper import cli, build_train, HParams, registry, Preprocessing
+
+
+class ImageClassification(Preprocessing):
+    @property
+    def kwargs(self):
+        return {
+            "input_shape": self.features["image"].shape,
+            "num_classes": self.features["label"].num_classes,
+        }
+
+    def inputs(self, data):
+        return tf.cast(data["image"], tf.float32)
+
+    def outputs(self, data):
+        return tf.one_hot(data["label"], self.features["label"].num_classes)


@registry.register_preprocess("mnist")
-def default(image_tensor):
-    return tf.cast(image_tensor, dtype=tf.float32) / 255
+class default(ImageClassification):
+    def inputs(self, data):
+        return super().inputs(data) / 255


@registry.register_model
-def cnn(hp, dataset):
+def cnn(hp, input_shape, num_classes):
    return tf.keras.models.Sequential(
        [
            tf.keras.layers.Conv2D(
-                hp.filters[0],
-                (3, 3),
-                activation=hp.activation,
-                input_shape=dataset.input_shape,
+                hp.filters[0], (3, 3), activation=hp.activation, input_shape=input_shape
            ),
            tf.keras.layers.MaxPooling2D((2, 2)),
            tf.keras.layers.Conv2D(hp.filters[1], (3, 3), activation=hp.activation),
            tf.keras.layers.MaxPooling2D((2, 2)),
            tf.keras.layers.Conv2D(hp.filters[2], (3, 3), activation=hp.activation),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(hp.filters[3], activation=hp.activation),
-            tf.keras.layers.Dense(dataset.num_classes, activation="softmax"),
+            tf.keras.layers.Dense(num_classes, activation="softmax"),
        ]
    )

@@ -49,7 +62,7 @@ class small(basic):
@cli.command()
@build_train()
def train(build_model, dataset, hparams, output_dir):
-    model = build_model(hparams, dataset)
+    model = build_model(hparams, **dataset.preprocessing.kwargs)
    model.compile(
        optimizer=hparams.optimizer,
        loss="categorical_crossentropy",
3 changes: 2 additions & 1 deletion zookeeper/__init__.py
@@ -1,4 +1,5 @@
from zookeeper.hparam import HParams
+from zookeeper.preprocessing import Preprocessing
from zookeeper.cli import build_train, cli

-__all__ = ["build_train", "cli", "HParams"]
+__all__ = ["build_train", "cli", "HParams", "Preprocessing"]
36 changes: 28 additions & 8 deletions zookeeper/cli_test.py
@@ -1,23 +1,43 @@
-from zookeeper import registry, cli, build_train, HParams, data
+from zookeeper import registry, cli, build_train, HParams, data, Preprocessing
from click.testing import CliRunner
import click
from unittest import mock
from os import path
import tensorflow as tf
+import tensorflow_datasets as tfds
+
+
+class ImageClassification(Preprocessing):
+    @property
+    def kwargs(self):
+        return {
+            "input_shape": self.features["image"].shape,
+            "num_classes": self.features["label"].num_classes,
+        }
+
+    def inputs(self, data):
+        return tf.cast(data["image"], tf.float32)
+
+    def outputs(self, data):
+        return tf.one_hot(data["label"], self.features["label"].num_classes)


@registry.register_preprocess("mnist")
-def default(image_tensor):
-    return image_tensor
+class default(ImageClassification):
+    def inputs(self, data):
+        return super().inputs(data) / 255


@registry.register_preprocess("mnist")
-def raw(image_bytes):
-    return tf.image.decode_image(image_bytes)
+class raw(ImageClassification):
+    decoders = {"image": tfds.decode.SkipDecoding()}
+
+    def inputs(self, data):
+        return tf.cast(tf.image.decode_image(data["image"]), tf.float32)


@registry.register_model
-def foo(hparams, dataset):
+def foo(hparams, **kwargs):
    return "foo-model"


@@ -36,7 +56,7 @@ def train(build_model, dataset, hparams, output_dir, custom_opt):
    assert isinstance(output_dir, str)
    assert isinstance(custom_opt, str)

-    model = build_model(hparams, dataset)
+    model = build_model(hparams, **dataset.preprocessing.kwargs)
    assert model == "foo-model"
    assert dataset.dataset_name == "mnist"
    assert dataset.train_examples == 60000
@@ -53,7 +73,7 @@ def train_val(build_model, dataset, hparams, output_dir):
    assert isinstance(dataset, data.Dataset)
    assert isinstance(output_dir, str)

-    model = build_model(hparams, dataset)
+    model = build_model(hparams, **dataset.preprocessing.kwargs)
    assert model == "foo-model"
    assert dataset.dataset_name == "mnist"
    assert dataset.train_examples == 54000
38 changes: 14 additions & 24 deletions zookeeper/data.py
@@ -6,37 +6,36 @@
import tensorflow_datasets as tfds


+def pass_training_kwarg(function, training=False):
+    if "training" in inspect.getfullargspec(function).args:
+        return functools.partial(function, training=training)
+    return function
+
+
class Dataset:
    def __init__(
        self,
        dataset_name,
-        preprocess_fn,
+        preprocess_cls,
        use_val_split=False,
        cache_dir=None,
        data_dir=None,
        version=None,
    ):
        self.dataset_name = dataset_name
-        self.preprocess_fn = preprocess_fn
        self.data_dir = data_dir
        self.cache_dir = cache_dir
        self.version = version

-        dataset_builder = tfds.builder(self.dataset_name_str)
-        self.info = dataset_builder.info
+        self.info = tfds.builder(self.dataset_name_str).info
        splits = self.info.splits
-        features = self.info.features
+        self.preprocessing = preprocess_cls(features=self.info.features)

        if tfds.Split.TRAIN not in splits:
            raise ValueError("To train we require a train split in the dataset.")
        if tfds.Split.TEST not in splits and tfds.Split.VALIDATION not in splits:
            raise ValueError("We require a test or validation split in the dataset.")
-        if not {"image", "label"} <= set(self.info.supervised_keys or []):
-            raise NotImplementedError("We currently only support image classification")

-        self.num_classes = features["label"].num_classes
-        self.input_shape = getattr(
-            preprocess_fn, "input_shape", features["image"].shape
-        )
        self.train_split = tfds.Split.TRAIN
        self.train_examples = splits[self.train_split].num_examples
        if tfds.Split.TEST in splits:
@@ -67,23 +66,14 @@ def load_split(self, split, shuffle=True):
            name=self.dataset_name_str,
            split=split,
            data_dir=self.data_dir,
-            decoders={"image": tfds.decode.SkipDecoding()},
+            decoders=self.preprocessing.decoders,
            as_dataset_kwargs={"shuffle_files": shuffle},
        )

    def map_fn(self, data, training=False):
-        args = inspect.getfullargspec(self.preprocess_fn).args
-        image_or_bytes = (
-            data["image"]
-            if "image_bytes" == args[0]
-            else self.info.features["image"].decode_example(data["image"])
-        )
-        if "training" in args:
-            image = self.preprocess_fn(image_or_bytes, training=training)
-        else:
-            image = self.preprocess_fn(image_or_bytes)
-
-        return image, tf.one_hot(data["label"], self.num_classes)
+        input_fn = pass_training_kwarg(self.preprocessing.inputs, training=training)
+        output_fn = pass_training_kwarg(self.preprocessing.outputs, training=training)
+        return input_fn(data), output_fn(data)

    def get_cache_path(self, split_name):
        if self.cache_dir is None:
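The new `pass_training_kwarg` helper is what lets preprocessing methods opt in to the `training` flag: `map_fn` wraps `inputs` and `outputs` with it, so a subclass may declare `training` or leave it out. A standalone sketch of the behavior in plain Python:

```python
import functools
import inspect

def pass_training_kwarg(function, training=False):
    # Bind `training` only if the function declares it in its signature.
    if "training" in inspect.getfullargspec(function).args:
        return functools.partial(function, training=training)
    return function

def augment(data, training=False):
    return ("augmented" if training else "plain", data)

def identity(data):
    return data

print(pass_training_kwarg(augment, training=True)("x"))   # ('augmented', 'x')
print(pass_training_kwarg(identity, training=True)("x"))  # x
```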
12 changes: 10 additions & 2 deletions zookeeper/data_test.py
@@ -1,11 +1,19 @@
import pytest
from unittest import mock
-from zookeeper import data
+from zookeeper import data, Preprocessing
+
+
+class MockPreprocessing(Preprocessing):
+    def inputs(self, data):
+        return data["image"]
+
+    def outputs(self, data):
+        return data["label"]


@mock.patch("os.makedirs")
def test_cache_dir(os_makedirs):
-    dataset = data.Dataset("mnist", lambda x: x)
+    dataset = data.Dataset("mnist", MockPreprocessing)
    assert dataset.get_cache_path("train") == None

    dataset.cache_dir = "memory"
3 changes: 1 addition & 2 deletions zookeeper/data_vis.py
@@ -29,9 +29,8 @@ def plot_examples(dataset):

def plot_all_examples(set):
    dataset = set.load_split(set.train_split, shuffle=False)
-    decoder = set.info.features["image"].decode_example

-    raw = plot_examples(dataset.map(lambda feat: decoder(feat["image"])))
+    raw = plot_examples(dataset.map(lambda feat: feat["image"]))
    train = plot_examples(dataset.map(lambda feat: set.map_fn(feat, training=True)[0]))
    eval = plot_examples(dataset.map(lambda feat: set.map_fn(feat)[0]))
    return raw, train, eval
53 changes: 53 additions & 0 deletions zookeeper/preprocessing.py
@@ -0,0 +1,53 @@
+import abc
+
+
+class Preprocessing(abc.ABC):
+    """An abstract class used to define data preprocessing.
+
+    We define a neural network to be a mapping between an input and an output, hence we
+    define two abstract methods: `inputs` (e.g. the image for image classification) and
+    `outputs` (e.g. the class label for image classification). We also define `decoders`,
+    which allows us to customize the decoding, and `kwargs`, which allows us to pass
+    information from the preprocessing to the model (e.g. input size).
+
+    # Arguments
+    features: A [`tfds.features.FeaturesDict`](https://www.tensorflow.org/datasets/api_docs/python/tfds/features/FeaturesDict)
+
+    # Properties
+    - `decoders`: Nested `dict` of [`Decoder`](https://www.tensorflow.org/datasets/api_docs/python/tfds/decode/Decoder)
+        objects which allow you to customize the decoding. The structure should match the
+        feature structure, but only customized feature keys need to be present.
+        See [the guide](https://www.tensorflow.org/datasets/decode) for more info.
+    - `kwargs`: A `dict` that can be used to pass additional keyword arguments to a
+        model function.
+    """
+
+    decoders = None
+    kwargs = {}
+
+    def __init__(self, features=None):
+        self.features = features
+
+    @abc.abstractmethod
+    def inputs(self, data, training):
+        """A method to define preprocessing for inputs.
+
+        This method needs to be overwritten by all subclasses.
+
+        # Arguments
+        data: A dictionary of type {feature_name: Tensor}
+        training: An optional `boolean` to indicate whether preprocessing is called during training.
+        """
+        raise NotImplementedError("Must be implemented in subclasses.")
+
+    @abc.abstractmethod
+    def outputs(self, data, training):
+        """A method to define preprocessing for outputs.
+
+        This method needs to be overwritten by all subclasses.
+
+        # Arguments
+        data: A dictionary of type {feature_name: Tensor}
+        training: An optional `boolean` to indicate whether preprocessing is called during training.
+        """
+        raise NotImplementedError("Must be implemented in subclasses.")
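As a rough usage sketch of the new abstract class (it assumes TFDS can download MNIST; the `Scaled` subclass name is illustrative):

```python
import tensorflow as tf
import tensorflow_datasets as tfds
from zookeeper import Preprocessing

class Scaled(Preprocessing):
    def inputs(self, data):
        return tf.cast(data["image"], tf.float32) / 255

    def outputs(self, data):
        return tf.one_hot(data["label"], self.features["label"].num_classes)

info = tfds.builder("mnist").info
prep = Scaled(features=info.features)

# Mirrors what Dataset.map_fn does: map raw feature dicts to (input, output) pairs.
ds = tfds.load("mnist", split="train")
ds = ds.map(lambda data: (prep.inputs(data), prep.outputs(data)))
```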