Skip to content

Commit

Permalink
Merge pull request tensorflow#47 from tensorflow/internal-to-github-sync
Browse files Browse the repository at this point in the history
Merge internal changes into public repository (change 181251654)
  • Loading branch information
sb2nov committed Jan 9, 2018
2 parents 1dfee59 + a2a75e0 commit d10a0f2
Show file tree
Hide file tree
Showing 18 changed files with 3,422 additions and 155 deletions.
139 changes: 139 additions & 0 deletions cloud_tpu/datasets/fake_data_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Script to create a fake dataset to test out ResNet-50 and related models.
To run the script setup a virtualenv with the following libraries installed.
- `gcloud`: Follow the instructions on
[cloud SDK docs](https://cloud.google.com/sdk/downloads) followed by
installing the python api using `pip install google-cloud`.
- `tensorflow`: Install with `pip install tensorflow`
- `Pillow`: Install with `pip install pillow`
You can run the script using the following command.
```
python fake_data_generator.py \
--project="TEST_PROJECT" \
--gcs_output_path="gs://TEST_BUCKET/DATA_DIR"
```
"""

import os
import StringIO
import numpy as np
from PIL import Image
import tensorflow as tf

tf.flags.DEFINE_string('project', None,
'Google cloud project id for uploading the dataset.')
tf.flags.DEFINE_string('gcs_output_path', None,
'GCS path for uploading the dataset.')
tf.flags.DEFINE_integer('examples_per_shard', 5000, '')
tf.flags.DEFINE_integer('num_label_classes', 1000, '')
tf.flags.DEFINE_integer('training_shards', 260, '')
tf.flags.DEFINE_integer('validation_shards', 10, '')

FLAGS = tf.flags.FLAGS

TRAINING_PREFIX = 'train'
VALIDATION_PREFIX = 'validation'


def _int64_feature(value):
"""Wrapper for inserting int64 features into Example proto."""
if not isinstance(value, list):
value = [value]
return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


def _bytes_feature(value):
"""Wrapper for inserting bytes features into Example proto."""
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _create_example(label):
"""Build an Example proto for a single randomly generated image."""
colorspace = 'RGB'
channels = 3
image_format = 'JPEG'
height = 224
width = 224

# Create a random image
image = (np.random.rand(height, width, channels) * 255).astype('uint8')
image = Image.fromarray(image)
image_buffer = StringIO.StringIO()
image.save(image_buffer, format=image_format)

example = tf.train.Example(
features=tf.train.Features(
feature={
'image/height': _int64_feature(height),
'image/width': _int64_feature(width),
'image/colorspace': _bytes_feature(colorspace),
'image/channels': _int64_feature(channels),
'image/class/label': _int64_feature(label),
'image/format': _bytes_feature(image_format),
'image/encoded': _bytes_feature(image_buffer.getvalue())
}))
return example


def _create_random_file(output_file):
"""Create a single tf-record file with multiple examples for each class."""
writer = tf.python_io.TFRecordWriter(output_file)
examples_per_class = int(FLAGS.examples_per_shard / FLAGS.num_label_classes)

assert examples_per_class > 0, 'Number of examples per class should be >= 1'

for label in range(FLAGS.num_label_classes):
for _ in range(examples_per_class):
example = _create_example(label)
writer.write(example.SerializeToString())
writer.close()


def create_tf_records(data_dir):
"""Create random data and write it to tf-record files."""
def _create_records(prefix, num_shards):
"""Create records in a given directory."""
for shard in range(num_shards):
filename = os.path.join(data_dir, '%s-%.5d-of-%.5d' % (prefix, shard,
num_shards))
_create_random_file(filename)

tf.logging.info('Processing the training data.')
_create_records(TRAINING_PREFIX, FLAGS.training_shards)

tf.logging.info('Processing the validation data.')
_create_records(VALIDATION_PREFIX, FLAGS.validation_shards)


def main(argv): # pylint: disable=unused-argument
tf.logging.set_verbosity(tf.logging.INFO)

if FLAGS.project is None:
raise ValueError('GCS Project must be provided.')

if FLAGS.gcs_output_path is None:
raise ValueError('GCS output path must be provided.')
elif not FLAGS.gcs_output_path.startswith('gs://'):
raise ValueError('GCS output path must start with gs://')

# Create fake tf-records
create_tf_records(FLAGS.gcs_output_path)


if __name__ == '__main__':
tf.app.run()
9 changes: 4 additions & 5 deletions cloud_tpu/datasets/imagenet_to_gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,22 +410,21 @@ def upload_to_gcs(training_records, validation_records):
client = storage.Client(project=FLAGS.project)
bucket = client.get_bucket(bucket_name)

def _upload_files(filenames, subdirectory):
def _upload_files(filenames):
"""Upload a list of files into a specifc subdirectory."""
for i, filename in enumerate(sorted(filenames)):
blob = bucket.blob(key_prefix + subdirectory + '/' +
os.path.basename(filename))
blob = bucket.blob(key_prefix + os.path.basename(filename))
blob.upload_from_filename(filename)
if not i % 20:
tf.logging.info('Finished uploading file: %s' % filename)

# Upload training dataset
tf.logging.info('Uploading the training data.')
_upload_files(training_records, TRAINING_DIRECTORY)
_upload_files(training_records)

# Upload validation dataset
tf.logging.info('Uploading the validation data.')
_upload_files(validation_records, VALIDATION_DIRECTORY)
_upload_files(validation_records)


def main(argv): # pylint: disable=unused-argument
Expand Down
59 changes: 59 additions & 0 deletions cloud_tpu/jupyterhub/launch.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#!/bin/bash
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# This script tries to launch the Jupyterhub notebook in this directory if it
# is running in a GCE instance.
#
# 1. Verifies that Jupyterhub, TensorFlow, and gcloud are installed properly
# 2. Tags the current instance with `cloud-tpu-demo-notebook`.
# 3. Creates a firewall rule that opens port 8888 (for Jupyterhub) and port
# 6006 (for TensorBoard) for all instances tagged `cloud-tpu-demo-notebook`.
# 4. Starts Jupyterhub.

version_lte() {
[ "$1" = "`echo -e "$1\n$2" | sort -V | head -n1`" ]
}

# Ensure that pip is installed
command -v pip >/dev/null 2>&1 || { echo "To run this tutorial, we need pip to be installed. You can install pip by running `sudo apt-get install python-pip`."; exit 1; }

# Ensure that gcloud is installed
command -v gcloud >/dev/null 2>&1 || { echo "To run this tutorial, we need the Google Cloud SDK. Please see https://cloud.google.com/sdk/downloads for instructions."; exit 1; }

# Ensure that TensorFlow is installed
TF_VERSION=`python -c "import tensorflow; print tensorflow.__version__" 2>/dev/null`
version_lte $TF_VERSION 1.5 && (echo "Your version of TensorFlow is too low. You must install at least version 1.5.0."; exit 1;)

# Ensure that Jupyter is installed
command -v jupyter >/dev/null 2>&1 || { sudo pip install jupyter; }

# Retrieve the instance name and zone of the current instance
INSTANCE_NAME=`curl -H "Metadata-Flavor: Google" http://metadata.google.internal/computeMetadata/v1/instance/name 2>/dev/null`
INSTANCE_ZONE=`curl -H "Metadata-Flavor: Google" http://metadata.google.internal/computeMetadata/v1/instance/zone 2>/dev/null`

# Add `cloud-tpu-demo-notebook` tag to current instance
gcloud compute instances add-tags $INSTANCE_NAME --tags cloud-tpu-demo-notebook --zone $INSTANCE_ZONE

# Add firewall rule to open tcp:6006,8888 for `cloud-tpu-demo-notebook`
gcloud compute firewall-rules create cloud-tpu-demo-notebook --target-tags=cloud-tpu-demo-notebook --allow=tcp:6006,tcp:8888

# Print out JupyterHub URL
echo
echo The Jupyterhub is at: http://`curl -H "Metadata-Flavor: Google" http://metadata/computeMetadata/v1/instance/network-interfaces/0/access-configs/0/external-ip 2> /dev/null`:8888/
echo

# Launch JupyterHub
jupyter notebook --no-browser --ip=0.0.0.0
146 changes: 146 additions & 0 deletions cloud_tpu/models/cifar_keras/cifar_keras.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Cifar example using Keras for model definition."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from tensorflow.contrib.tpu.python.tpu import tpu_config
from tensorflow.contrib.tpu.python.tpu import tpu_estimator
from tensorflow.contrib.tpu.python.tpu import tpu_optimizer

tf.flags.DEFINE_integer("batch_size", 128,
"Mini-batch size for the computation. Note that this "
"is the global batch size and not the per-shard batch.")
tf.flags.DEFINE_float("learning_rate", 0.05, "Learning rate.")
tf.flags.DEFINE_string("train_file", "", "Path to cifar10 training data.")
tf.flags.DEFINE_integer("train_steps", 100000,
"Total number of steps. Note that the actual number of "
"steps is the next multiple of --iterations greater "
"than this value.")
tf.flags.DEFINE_bool("use_tpu", True, "Use TPUs rather than plain CPUs")
tf.flags.DEFINE_string("master", "",
"BNS name of the TensorFlow master to use.")
tf.flags.DEFINE_string("model_dir", None, "Estimator model_dir")
tf.flags.DEFINE_integer("iterations", 100,
"Number of iterations per TPU training loop.")
tf.flags.DEFINE_integer("num_shards", 8, "Number of shards (TPU chips).")


FLAGS = tf.flags.FLAGS


def model_fn(features, labels, mode, params):
"""Define a CIFAR model in Keras."""
del params # unused
layers = tf.contrib.keras.layers

# Pass our input tensor to initialize the Keras input layer.
v = layers.Input(tensor=features)
v = layers.Conv2D(filters=32, kernel_size=5,
activation="relu", padding="same")(v)
v = layers.MaxPool2D(pool_size=2)(v)
v = layers.Conv2D(filters=64, kernel_size=5,
activation="relu", padding="same")(v)
v = layers.MaxPool2D(pool_size=2)(v)
v = layers.Flatten()(v)
fc1 = layers.Dense(units=512, activation="relu")(v)
logits = layers.Dense(units=10)(fc1)

# Instead of constructing a Keras model for training, build our loss function
# and optimizer in Tensorflow.
#
# N.B. This construction omits some features that are important for more
# complex models (e.g. regularization, batch-norm). Once
# `model_to_estimator` support is added for TPUs, it should be used instead.
loss = tf.reduce_mean(
tf.nn.sparse_softmax_cross_entropy_with_logits(
logits=logits, labels=labels
)
)
optimizer = tf.train.AdamOptimizer()
if FLAGS.use_tpu:
optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)

train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())

return tpu_estimator.TPUEstimatorSpec(
mode=mode,
loss=loss,
train_op=train_op,
predictions={
"classes": tf.argmax(input=logits, axis=1),
"probabilities": tf.nn.softmax(logits, name="softmax_tensor")
}
)


def input_fn(params):
"""Read CIFAR input data from a TFRecord dataset."""
del params
batch_size = FLAGS.batch_size
def parser(serialized_example):
"""Parses a single tf.Example into image and label tensors."""
features = tf.parse_single_example(
serialized_example,
features={
"image": tf.FixedLenFeature([], tf.string),
"label": tf.FixedLenFeature([], tf.int64),
})
image = tf.decode_raw(features["image"], tf.uint8)
image.set_shape([3*32*32])
image = tf.cast(image, tf.float32) * (1. / 255) - 0.5
image = tf.transpose(tf.reshape(image, [3, 32, 32]))
label = tf.cast(features["label"], tf.int32)
return image, label

dataset = tf.data.TFRecordDataset([FLAGS.train_file])
dataset = dataset.map(parser, num_parallel_calls=batch_size)
dataset = dataset.prefetch(4 * batch_size).cache().repeat()
dataset = dataset.apply(
tf.contrib.data.batch_and_drop_remainder(FLAGS.batch_size)
)
dataset = dataset.prefetch(1)
images, labels = dataset.make_one_shot_iterator().get_next()
return images, labels


def main(argv):
del argv # Unused.

run_config = tpu_config.RunConfig(
master=FLAGS.master,
model_dir=FLAGS.model_dir,
save_checkpoints_secs=3600,
session_config=tf.ConfigProto(
allow_soft_placement=True, log_device_placement=True),
tpu_config=tpu_config.TPUConfig(
iterations_per_loop=FLAGS.iterations, num_shards=FLAGS.num_shards),
)

estimator = tpu_estimator.TPUEstimator(
model_fn=model_fn,
use_tpu=FLAGS.use_tpu,
config=run_config,
train_batch_size=FLAGS.batch_size)
estimator.train(input_fn=input_fn, max_steps=FLAGS.train_steps)


if __name__ == "__main__":
tf.logging.set_verbosity(tf.logging.INFO)
tf.app.run(main)
Loading

0 comments on commit d10a0f2

Please sign in to comment.