Flax Masked Language Modeling training example #8728
Conversation
# Model forward
# TODO: Remove this conversion by replacing the collator
model_inputs = {var_name: tensor.numpy() for var_name, tensor in model_inputs.items()}
loss, optimizer = training_step(optimizer, model_inputs)
This is really cool!
Can you also easily have a learning rate schedule?
Yep, sure. I didn't put it in at first, focusing on making things clear and almost a no-brainer 😄. Will look at it soon 👍
Now included 👍
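For reference, a schedule in the JAX ecosystem can stay very small. The sketch below uses optax with made-up step counts and a made-up peak rate, so every name and number here is an assumption rather than what this PR actually wires up:

```python
import optax

# Assumed values for illustration; in practice, derive them from the dataset
# size, batch size and num_train_epochs.
num_train_steps = 10_000
warmup_steps = 1_000
peak_lr = 5e-5

# Linear warmup to the peak rate, then linear decay to zero.
schedule = optax.join_schedules(
    schedules=[
        optax.linear_schedule(init_value=0.0, end_value=peak_lr, transition_steps=warmup_steps),
        optax.linear_schedule(init_value=peak_lr, end_value=0.0, transition_steps=num_train_steps - warmup_steps),
    ],
    boundaries=[warmup_steps],
)

# Passing the schedule as `learning_rate` makes the optimizer re-evaluate it every step.
tx = optax.adamw(learning_rate=schedule, weight_decay=0.01)
```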
for epoch in track(range(int(training_args.num_train_epochs)), description="Training..."):
    samples_idx = np.random.choice(len(tokenized_datasets["train"]), (training_args.train_batch_size,))
    samples = [tokenized_datasets["train"][idx.item()] for idx in samples_idx]
    model_inputs = data_collator(samples)
What's your first impression of having a FLAX Trainer with a similar API to the PT Trainer at some point, @sgugger?
Doesn't look like it's going to be too hard to build.
Argh, you actually took the one example that is a bit flaky (we merged DataCollatorForWholeWordMasking a bit too fast, and the data preprocessing part of this script needs to be completely rewritten as it works for BERT only for now). Could you do the same with the run_mlm script instead? This one won't change :-)
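To make the Trainer idea concrete, here is a purely hypothetical sketch of what such an API could look like; no class like this exists in the PR or in the library at this point, and all names are assumptions:

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional

@dataclass
class FlaxTrainer:
    """Hypothetical Flax counterpart to the PyTorch Trainer; purely illustrative."""
    model: Any
    args: Any                               # would mirror TrainingArguments
    data_collator: Optional[Callable] = None
    train_dataset: Any = None
    eval_dataset: Any = None

    def train(self):
        # Would own the epoch/step loop, the sampling and collation shown above,
        # and the jitted training_step, just like Trainer.train() does in PyTorch.
        raise NotImplementedError

    def evaluate(self):
        raise NotImplementedError
```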
Force-pushed from 3bb9f4f to 7ffb79a
Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
https://huggingface.co/models?filter=masked-lm
"""
# You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
Suggested change:
- # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
+ # You can also adapt this script to your own masked language modeling task. Pointers for this are left as comments.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) with whole word masking on a
I find it quite useful to have a comment at the top of a binary with an example command line that lets users run the code directly. What do you think of this?
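Something along the lines of the sketch below, for instance; the flags shown are illustrative and not necessarily the script's actual arguments:

```python
"""
Fine-tuning the library models for masked language modeling with whole word masking.

Example (arguments are illustrative):
    python run_mlm_flax.py \
        --model_name_or_path bert-base-cased \
        --dataset_name wikitext \
        --dataset_config_name wikitext-2-raw-v1 \
        --max_seq_length 128 \
        --output_dir ./bert-mlm-flax
"""
```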
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
It seems we will never fine-tune with this code, right? At least it looks like the model is always FlaxBertForMaskedLM, which has a pre-training objective.
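For illustration, the script could branch on `model_name_or_path` roughly like the sketch below; the constructor calls are assumptions about the Flax model API rather than what the PR currently does:

```python
from transformers import BertConfig, FlaxBertForMaskedLM

# `model_args` would be the script's parsed ModelArguments instance (assumed here).
if model_args.model_name_or_path is not None:
    # Start from pretrained weights: an actual fine-tuning setup.
    model = FlaxBertForMaskedLM.from_pretrained(model_args.model_name_or_path)
else:
    # Fresh randomly initialized weights: training from scratch with the MLM objective.
    model = FlaxBertForMaskedLM(BertConfig())
```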
if self.train_file is not None:
    extension = self.train_file.split(".")[-1]
    assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
if self.validation_file is not None:
    extension = self.validation_file.split(".")[-1]
    assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
Maybe create some inner function check_file_extension to avoid code duplication?
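Something like the sketch below, for example; the helper name comes from the suggestion, the message is taken from the existing asserts, and the calls would live inside `__post_init__`:

```python
def check_file_extension(path: str, arg_name: str) -> None:
    # Shared validation for both data files, replacing the duplicated asserts.
    extension = path.split(".")[-1]
    assert extension in ["csv", "json", "txt"], f"`{arg_name}` should be a csv, a json or a txt file."

if self.train_file is not None:
    check_file_extension(self.train_file, "train_file")
if self.validation_file is not None:
    check_file_extension(self.validation_file, "validation_file")
```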
# return -jnp.mean(jnp.sum(one_hot(labels, config.vocab_size) * logits, axis=-1), axis=-1)
#

def cross_entropy(logits, targets, label_smoothing=0.0):
Maybe factor this function out of training_step to make it easier to read?
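For concreteness, a standalone version could look like the sketch below; the label-smoothing formulation is a common one and not necessarily the exact math this PR uses:

```python
import jax.numpy as jnp
from jax.nn import log_softmax, one_hot

def cross_entropy(logits, targets, label_smoothing=0.0):
    vocab_size = logits.shape[-1]
    confidence = 1.0 - label_smoothing
    low_confidence = label_smoothing / max(vocab_size - 1, 1)

    # Smooth the one-hot targets: `confidence` mass on the gold token,
    # the remaining mass spread uniformly over the other tokens.
    hard_targets = one_hot(targets, vocab_size)
    soft_targets = hard_targets * confidence + (1.0 - hard_targets) * low_confidence

    # Token-level negative log-likelihood against the smoothed targets, averaged over tokens.
    loss = -jnp.sum(soft_targets * log_softmax(logits, axis=-1), axis=-1)
    return loss.mean()
```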
Force-pushed from ffc1f34 to 7dd4a85
Include a training example running with the Flax/JAX framework. (cc @avital @marcvanzee)
TODOs:
- bfloat16 (see the sketch below)
- float16
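One rough sketch of what bfloat16 support could mean on the JAX side, assuming it is done by casting parameters (the eventual implementation may instead expose a dtype flag on the model):

```python
import jax
import jax.numpy as jnp

def to_bf16(params):
    # Cast every floating-point leaf of the parameter pytree to bfloat16,
    # leaving integer/bool leaves untouched.
    return jax.tree_util.tree_map(
        lambda p: p.astype(jnp.bfloat16) if jnp.issubdtype(p.dtype, jnp.floating) else p,
        params,
    )
```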