Some updates to jax2tf
marcvanzee committed Apr 12, 2021
1 parent 5ed0633 commit 9e3001b
Showing 2 changed files with 41 additions and 18 deletions.
22 changes: 13 additions & 9 deletions jax/experimental/jax2tf/examples/tf_js/quickdraw/README.md
@@ -12,24 +12,28 @@ and is a part of the code featured in the aforementioned blog post.
 ## Training the model

 You can train and export the model yourself by running

 ```bash
 $ python3 quickdraw.py
 ```

-The training itself lasts roughly 30 minutes, assuming the dataset has already
-been downloaded. You can also skip to the [next section](#interacting-with-the-model)
-to see instructions for playing with a pre-trained model.
+This will first download the dataset if it is not downloaded yet, which is about
+11Gb in total. Assuming the dataset has been downloaded already, training for 5
+epochs takes roughly 10 minutes on a CPU (3,5 GHz Dual-Core Intel Core i7). You
+can also skip to the [next section](#interacting-with-the-model) to see
+instructions for playing with a pre-trained model.

 The dataset will be downloaded directly into a `data/` directory in the
 `/tmp/jax2tf/tf_js_quickdraw` directory; by default, the model is configured to
-classify inputs into 100 different classes, which corresponds to a dataset of
-roughly 11 Gb. This can be tweaked by modifying the value of the `NB_CLASSES`
-global variable in `quickdraw.py` (max number of classes: 100). Only the files
-corresponding to the first `NB_CLASSES` classes in the dataset will be
+classify inputs into 100 different classes. This can be tweaked using the
+command-line argument `--num_classes` (max number of classes: 100). Only the
+files corresponding to the first `num_classes` classes in the dataset will be
 downloaded.

-The training loop runs for 5 epochs, and the model as well as its equivalent
-TF.js-loadable model are subsequently saved into `/tmp/jax2tf/tf_js_quickdraw`.
+The training loop runs for 5 epochs by default (this can be changed using
+the command-line argument `--num_epochs`), and the model as well as its
+equivalent TF.js-loadable model are subsequently saved into
+`/tmp/jax2tf/tf_js_quickdraw`.

 ## Interacting with the model
37 changes: 28 additions & 9 deletions jax/experimental/jax2tf/examples/tf_js/quickdraw/quickdraw.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from absl import app # type: ignore
+from absl import flags
+
 import os # type: ignore
 import time
 from typing import Callable
@@ -35,7 +37,18 @@

 import utils

-NB_CLASSES = 100
+flags.DEFINE_boolean("run_eval_on_train", False,
+                     ("Also run eval on the train set after each epoch. This "
+                      "slows down training considerably."))
+flags.DEFINE_integer("num_epochs", 5,
+                     ("Number of epochs to train for."))
+flags.DEFINE_integer("num_classes", 100, "Number of classification classes.")
+
+flags.register_validator('num_classes',
+                         lambda value: value >= 1 and value <= 100,
+                         message='--num_classes must be in range [1, 100]')
+
+FLAGS = flags.FLAGS

 # The code below is an adaptation for Flax from the work published here:
 # https://blog.tensorflow.org/2018/07/train-model-in-tfkeras-with-colab-and-run-in-browser-tensorflowjs.html
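
The `register_validator` call in this hunk is evaluated when the flags are parsed, not when the module is imported. Below is a minimal sketch of that behaviour, assuming `absl-py` is installed. It uses a private `flags.FlagValues()` registry (`demo_flags`, an illustrative name that is not part of the commit) and a hypothetical `parse` helper, so the global `FLAGS` object is left untouched.

```python
from absl import flags

# Hypothetical private registry so this sketch does not touch the global FLAGS.
demo_flags = flags.FlagValues()

flags.DEFINE_integer("num_classes", 100, "Number of classification classes.",
                     flag_values=demo_flags)
flags.register_validator("num_classes",
                         lambda value: 1 <= value <= 100,
                         message="--num_classes must be in range [1, 100]",
                         flag_values=demo_flags)


def parse(argv):
  """Parses argv and returns num_classes, or the validation error text."""
  try:
    demo_flags(argv)  # Parsing runs every registered validator.
  except flags.IllegalFlagValueError as err:
    return str(err)
  return demo_flags.num_classes


accepted = parse(["quickdraw.py", "--num_classes=10"])
rejected = parse(["quickdraw.py", "--num_classes=101"])
print(accepted)
print(rejected)
```

When a script is started through `app.run`, the same failure is reported as a usage error before `main` is called. `flags.DEFINE_integer` also accepts `lower_bound` and `upper_bound` arguments, which would be another way to express this range.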
@@ -59,7 +72,7 @@ def __call__(self, x):
     x = nn.Dense(features=128)(x)
     x = nn.relu(x)

-    x = nn.Dense(features=NB_CLASSES)(x)
+    x = nn.Dense(features=FLAGS.num_classes)(x)
     x = nn.softmax(x)

     return x
@@ -110,25 +123,31 @@ def init_model():

 def train(train_ds, test_ds, classes):
   optimizer, params = init_model()
-  for epoch in range(5):
+  for epoch in range(1, FLAGS.num_epochs+1):
     start_time = time.time()
     optimizer = train_one_epoch(optimizer, train_ds)
-    epoch_time = time.time() - start_time
-    train_acc = accuracy(predict, optimizer.target, train_ds)
+
+    if FLAGS.run_eval_on_train:
+      train_acc = accuracy(predict, optimizer.target, train_ds)
+      print("Training set accuracy {}".format(train_acc))
+
     test_acc = accuracy(predict, optimizer.target, test_ds)
-    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
-    print("Training set accuracy {}".format(train_acc))
     print("Test set accuracy {}".format(test_acc))
+    epoch_time = time.time() - start_time
+    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))

   return optimizer.target

 def main(*args):
   base_model_path = "/tmp/jax2tf/tf_js_quickdraw"
   dataset_path = os.path.join(base_model_path, "data")
-  classes = utils.download_dataset(dataset_path, NB_CLASSES)
-  assert len(classes) == NB_CLASSES, classes
+  num_classes = FLAGS.num_classes
+  classes = utils.download_dataset(dataset_path, num_classes)
+  assert len(classes) == num_classes, classes
   print(f"Classes are: {classes}")
   print("Loading dataset into memory...")
   train_ds, test_ds = utils.load_classes(dataset_path, classes)
+  print(f"Starting training for {FLAGS.num_epochs} epochs...")
   flax_params = train(train_ds, test_ds, classes)

   model_dir = os.path.join(base_model_path, "saved_models")
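
As the `train` hunk reads, `epoch_time` is now computed after the accuracy calls, so the printed duration appears to cover evaluation as well as training. Here is a minimal sketch of that loop structure which records the training portion separately. `fake_train_one_epoch`, `fake_accuracy` and `run_epochs` are hypothetical stand-ins rather than the commit's helpers, and `time.perf_counter` is swapped in for `time.time` because it is monotonic.

```python
import time


def fake_train_one_epoch():
  """Hypothetical stand-in for train_one_epoch; just sleeps briefly."""
  time.sleep(0.02)


def fake_accuracy():
  """Hypothetical stand-in for accuracy(); sleeps and returns a dummy value."""
  time.sleep(0.01)
  return 0.5


def run_epochs(num_epochs, run_eval_on_train=False):
  """Mirrors the loop structure in the diff and records timings per epoch."""
  timings = []
  for epoch in range(1, num_epochs + 1):  # 1-based epoch numbers, as in the diff
    start_time = time.perf_counter()
    fake_train_one_epoch()
    train_time = time.perf_counter() - start_time

    if run_eval_on_train:
      fake_accuracy()  # Optional extra pass over the training set.
    fake_accuracy()  # Test-set evaluation always runs.

    epoch_time = time.perf_counter() - start_time  # Includes evaluation.
    timings.append((epoch, train_time, epoch_time))
    print("Epoch {} in {:0.2f} sec ({:0.2f} sec training)".format(
        epoch, epoch_time, train_time))
  return timings


timings = run_epochs(num_epochs=2)
```

Reporting both numbers keeps the per-epoch figure comparable across runs with and without `--run_eval_on_train`.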