Adding support for distributed TensorFlow. #14
Conversation
…isn't informative anymore.
… on line 149 of average_precision_calculator.py from list type to zip-type again as per Python 3 changes (#17)
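For context, the Python 3 behavior that commit works around (a minimal illustration; the names below are made up, not taken from average_precision_calculator.py):

# Python 2: zip() returns a list. Python 3: zip() returns a lazy iterator,
# so code that indexes it or iterates it twice must materialize it first.
predictions = [0.9, 0.1, 0.8]  # illustrative values
actuals = [1, 0, 1]
pairs = list(zip(predictions, actuals))  # behaves the same on 2 and 3
print(pairs[0])  # (0.9, 1)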
train.py
Outdated
data_pattern + "'.") | ||
logging.info("Number of training files: %s.", str(len(files))) | ||
filename_queue = tf.train.string_input_producer( | ||
files, num_epochs=num_epochs) |
Add shuffle here...
Added the parameter "shuflle=True" to string_input_producer.
It is the default but it seems that being explicit can improve readability in this case.
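A minimal sketch of the resulting call, assuming the TF 1.x queue-based input pipeline from the diff (files and num_epochs come from the surrounding code):

import tensorflow as tf  # TF 1.x API

# shuffle=True is the default; spelled out here for readability.
filename_queue = tf.train.string_input_producer(
    files, num_epochs=num_epochs, shuffle=True)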
task_as_string(self.task), self.cluster.as_dict())
server = start_server(self.cluster, self.task)
target = server.target
device_fn = tf.train.replica_device_setter(
You need to say what the ps_device is.
I don't know what merge_devices does; I haven't used it.
Added the parameter '"ps_device="/job:ps"'. It is the default value but it seems that being explicit can improve readability.
Removing the parameter "merge_devices=True". It is also the default but this parameter is on the path to deprecation and specifying "merge_devices=False" triggers a warning.
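A sketch of the resulting device function under the TF 1.x distributed runtime; the worker_device string is an assumption for illustration and is not taken from the diff:

# ps_device="/job:ps" is the default, spelled out for readability.
# merge_devices is left at its default (True), since passing it explicitly
# is on the path to deprecation.
device_fn = tf.train.replica_device_setter(
    ps_device="/job:ps",
    worker_device="/job:%s/task:%d" % (self.task.type, self.task.index),  # assumed naming
    cluster=self.cluster)

with tf.device(device_fn):
    # Variables land on the ps job; other ops stay on the local worker.
    ...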
…multiple parameter servers. This CR also contains a couple of changes that explicitly specify some default parameters to improve readability.
train.py
Outdated
flags.DEFINE_string("optimizer", "AdamOptimizer", | ||
"What optimizer class to use.") | ||
flags.DEFINE_bool("log_device_placement", False, | ||
"Whether device placement should be logged.") |
How about "Whether to write the device every op will run on into the logs on startup".
Changed to "Whether to write the device on which every op will run into the logs on startup."
train.py
Outdated
training_data = [
    reader.prepare_reader(filename_queue) for _ in xrange(num_readers)]
    reader.prepare_reader(filename_queue) for _ in xrange(num_readers)
Please sync to head and test this in Python 3. We are trying to maintain compatibility with both Python versions now.
Synced to head.
Tested with Python 3:
python3 youtube-8m-vicaire/train.py --train_data_pattern='/tmp/features/train*.tfrecord' --train_dir=/tmp/features/video_level_logistic_model
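For reference, one common shim that keeps the xrange call from the diff working under both interpreters (a sketch; the repo may resolve this differently):

try:
    xrange  # Python 2 built-in
except NameError:
    xrange = range  # Python 3: range is already lazy

training_data = [
    reader.prepare_reader(filename_queue) for _ in xrange(num_readers)
]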
So there's good news and bad news.

👍 The good news is that everyone who needs to sign a CLA (the pull request submitter and all commit authors) has done so. Everything is all good there.

😕 The bad news is that it appears that one or more commits were authored by someone other than the pull request submitter. We need to confirm that they're okay with their commits being contributed to this project. Please have them confirm that here in the pull request.

Note to project maintainer: This is a terminal state, meaning the …
Adding support for distributed TensorFlow.
Tests:
Local execution, non-distributed:
gcloud --verbosity=debug beta ml local train --package-path=youtube-8m-private --module-name=youtube-8m-private.train -- --train_data_pattern='gs://youtube8m-ml/1/video_level/train/train*.tfrecord' --train_dir=/tmp/yt8m_train --start_new_model
Local execution, distributed:
gcloud beta ml local train --package-path=youtube-8m-private --module-name=youtube-8m-private.train --distributed --parameter-server-count=1 --worker-count=4 -- --train_data_pattern='gs://youtube8m-ml/1/video_level/train/train*.tfrecord' --train_dir=/tmp/yt8m_train --start_new_model
Running on your own machine, Python 2.7:
python youtube-8m-private/train.py --train_data_pattern='/tmp/features/train*.tfrecord' --train_dir=/tmp/features/video_level_logistic_model --start_new_model
Running on your own machine, Python 3:
python3 youtube-8m-vicaire/train.py --train_data_pattern='/tmp/features/train*.tfrecord' --train_dir=/tmp/features/video_level_logistic_model
Distributed execution on cloud:
BUCKET_NAME=...; JOB_NAME=yt8m_train_$(date +%Y%m%d_%H%M%S); gcloud --verbosity=debug beta ml jobs submit training $JOB_NAME --package-path=youtube-8m-private --module-name=youtube-8m-private.train --staging-bucket=$BUCKET_NAME --region=us-central1 --config=youtube-8m-private/cloudml-gpu-distributed.yaml -- --train_data_pattern='gs://youtube8m-ml/1/video_level/train/train*.tfrecord' --train_dir=$BUCKET_NAME/yt8m_train_video_level_logistic_model --start_new_model
Non-distributed execution on cloud:
BUCKET_NAME=...; JOB_NAME=yt8m_train_$(date +%Y%m%d_%H%M%S); gcloud --verbosity=debug beta ml jobs submit training $JOB_NAME --package-path=youtube-8m-private --module-name=youtube-8m-private.train --staging-bucket=$BUCKET_NAME --region=us-central1 --config=youtube-8m-private/cloudml-gpu.yaml -- --train_data_pattern='gs://youtube8m-ml/1/video_level/train/train*.tfrecord' --train_dir=$BUCKET_NAME/yt8m_train_video_level_logistic_model --start_new_model