Add cloud training support for BERT example #226
Conversation
Left some comments. I think overall, we need to support a few more things here.
- It should be possible to mix and match GCS buckets.
- It should be possible to run locally from data on a public GCS bucket, with local logging and saving.
I think the simplest way to do this would be to make our inputs behave like the gfile API, where anything with a gs:// prefix is handled correctly. We can do that for the most part by just calling gfile routines, instead of forking our code based on whether Google Cloud Storage is in use.
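The gfile approach suggested above could be sketched like this (a minimal sketch assuming TensorFlow is installed; the helper names `save_text`/`load_text` are illustrative, not from the PR):

```python
# tf.io.gfile exposes one filesystem API that transparently handles both
# local paths and gs:// URLs, so example code need not branch on backend.
import os
import tempfile

import tensorflow as tf


def save_text(path, contents):
    """Write `contents` to `path`; works for local and gs:// paths alike."""
    tf.io.gfile.makedirs(os.path.dirname(path))
    with tf.io.gfile.GFile(path, "w") as f:
        f.write(contents)


def load_text(path):
    with tf.io.gfile.GFile(path, "r") as f:
        return f.read()


# Local demo; swapping in a "gs://bucket/dir/file.txt" path would work
# unchanged, given GCS credentials.
tmp = tempfile.mkdtemp()
target = os.path.join(tmp, "logs", "run-01.txt")
save_text(target, "loss=0.42")
print(load_text(target))  # -> loss=0.42
tf.io.gfile.rmtree(tmp)
```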
    "Skip restoring from checkpoint if True",
)

flags.DEFINE_bool(
What does the full invocation end up looking like? Could we make this simpler, so that all of these paths could be gs:// paths, similar to how gfile handles this? So something like
bert_train.py \
--input_files=gs://foo-bucket/foo-bert-data \
--checkpoint_save_directory=gs://bar-bucket/training-01/pretrained-checkpoints \
--saved_model_output=gs://bar-bucket/training-01/pretrained-model \
...
I don't think we should demand that logging and source data come from the same bucket. And overall I think a flow like this would be simpler.
changed!
examples/bert/bert_train.py
Outdated
import shutil
import sys

import google.cloud.logging
Won't this import break the current quick start unless we are pip installing something?
changed!
examples/bert/bert_train.py
Outdated
def get_checkpoint_callback():
    if FLAGS.checkpoint_save_directory is not None:
        if FLAGS.use_cloud_storage:
            storage_client = storage.Client()
Can we just use GFile as a consistent interface? E.g. https://www.tensorflow.org/api_docs/python/tf/io/gfile/rmtree. GFile says it should support the gs:// prefix for the entire API.
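Concretely, the `storage.Client()` branch above could collapse into backend-agnostic gfile calls (a sketch assuming TensorFlow is installed; the `clear_checkpoint_directory` helper name is illustrative, not from the PR):

```python
# Using tf.io.gfile means the same cleanup code serves local directories
# and gs:// buckets, with no google-cloud-storage client needed.
import tensorflow as tf


def clear_checkpoint_directory(checkpoint_save_directory):
    """Reset the checkpoint directory regardless of storage backend."""
    if checkpoint_save_directory is not None:
        if tf.io.gfile.exists(checkpoint_save_directory):
            # rmtree works on local paths and gs:// prefixes alike.
            tf.io.gfile.rmtree(checkpoint_save_directory)
        tf.io.gfile.makedirs(checkpoint_save_directory)
```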
changed!
Force-pushed from d7affc2 to 1c51bfd: Quick notes from meeting
Thanks! This looks good to me.
Let's make sure we test the simple local invocation in the README before we merge, to make sure we didn't break anything.
Checked! The quick start still works!