Support Task Saving/Loading #1547
Conversation
Looks good! Just dropping some initial feedback!
keras_nlp/models/preprocessor.py (Outdated)

```python
def save_to_preset(self, preset):
    """TODO: add docstring."""
    save_to_preset(
```
Can we just call `self.tokenizer.save_to_preset(preset)` here? Also, I wonder if we should rename the `preset` arg to `path`. That would make it clearer we are looking for a filesystem path here.
Done!
keras_nlp/models/task.py (Outdated)

```python
self.preprocessor.save_to_preset(preset)
self.backbone.save_to_preset(preset)
weights_filename = "task.weights.h5"
```
Keep this as a constant? Just for consistency?
Good point! Done!
keras_nlp/models/task.py (Outdated)

```python
self.backbone.save_to_preset(preset)
weights_filename = "task.weights.h5"

# TODO: the serialization and saving logic should probably be moved to preset_utils.py
```
+1. Not exactly sure where the dividing lines should be, but we should probably clean up the division between the model code and the saving utils a bit.
keras_nlp/models/task.py (Outdated)

```python
)
weights_store.close()

# TODO: do we want to have a `save_weights` flag in this public save_to_preset? probably yes!
```
To save the architecture without weights? What's the use case?
I don't know what I was thinking when I wrote this :)) I was probably thinking about adding a `load_weights` flag anywhere we load weights (similar to what we have in `load_from_preset`)!
keras_nlp/models/task.py (Outdated)

```python
weights_store = keras.src.saving.saving_lib.H5IOStore(
    filepath, mode="r"
)
# Q: when loading task weights, there shouldn't be any backbone layers, why calculate and exclude backbone layers?
```
Otherwise I believe `_load_weights` will fail, because it will try to load the backbone layer weights and not find them in the file.
I see! Thanks for the explanation.
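For illustration, the layer-filtering idea discussed above might be sketched in plain Python (this is a hypothetical stand-in for KerasNLP's internal logic; the function name and signature are made up, and the real code works on Keras layer objects and H5 stores):

```python
def task_only_layers(task_layers, backbone_layers):
    """Return the layers owned by a task but not by its backbone.

    Collects the backbone layers into an id set and skips them, so only
    task-specific layers (e.g. a classification head) remain. The same
    exclusion is needed on load: the task weights file contains only
    these head layers, so backbone layers must not be looked up in it.
    """
    backbone_ids = {id(layer) for layer in backbone_layers}
    return [layer for layer in task_layers if id(layer) not in backbone_ids]
```

A shared layer appearing in both lists is filtered out, which mirrors why loading must compute and exclude backbone layer ids.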
keras_nlp/models/task.py (Outdated)

```python
task = keras.saving.deserialize_keras_object(task_config)
load_weights = load_weights and task_config["weights"]
task_weights_path = os.path.join(preset, task_config["weights"])
task.load_task_weights(task_weights_path)
```
We need to call `task.backbone.load_weights` too somewhere, right? Where do the backbone weights get loaded?
keras_nlp/models/task.py (Outdated)

```
@@ -267,59 +274,162 @@ def from_preset(
        "constructor with a `backbone` argument. "
        f"Received: backbone={kwargs['backbone']}."
    )

    # Load backbone from preset.
    config_path = os.path.join(preset, CONFIG_FILE)
```
I think we need to rework this to a slightly new flow. Something like this...
```python
# Backbone case.
if not exists("task.json") or not issubclass(check_config_class("task.json"), cls):
    # This should be basically what is here already.
    # Load a preprocessor. Load a backbone.
    return cls(backbone=backbone, preprocessor=preprocessor, **kwargs)
# Task case.
task = keras.saving.deserialize_keras_object("task.json")
if load_weights:
    task.backbone.load_weights("model.weights.h5")
    task.load_task_weights("task.weights.h5")
```
Basically, if we don't see a `task.json`, or the `task.json` is for a different task, we load the low-level objects and make a default task object with them. If we find a task object for our class, we load it exactly as it was before.

Let me know if that makes sense. I think the logic here is a bit different.
What I was thinking was to load the backbone and preprocessor in any case (whether we have a `task.json` or not).

- If there is no `task.json`, make a default task object with the backbone and preprocessor.
- If there is a `task.json`, load the task-specific things and assign the loaded `backbone` to `task.backbone`.

Do you see any issues or disadvantages with this approach? I'm mainly doing this because I thought it would be cleaner to just have one piece of code that loads the backbone.

PS: I made some changes this morning, so they may not have been included in the version of the code that you reviewed.
Assigning the backbone to the task would mean we double up on backbone memory until the next GC. We have to avoid that, I think; otherwise we will OOM people very easily.

One case I was getting at in my snippet but worth calling out explicitly: you saved, say, a `BertClassifier` but are loading a different task, e.g. a `BertMaskedLM`. In this case our `task.json` is from the wrong object, and I think we fall back to the "backbone case" here.
Oh, I see! Makes sense! Thanks for explaining this, Matt!
Thanks for the review, Matt!
Force-pushed from 981a150 to fa3c6fe ("…r.json doesn't exist.")
Thanks! Looking good. I think there are two main questions I see: how we structure preprocessor and task loading, and what we save in our JSON files. They are interrelated.
keras_nlp/models/backbone.py (Outdated)

```python
save_serialized_object(self, preset, config_file=CONFIG_FILE)
save_weights(self, preset, MODEL_WEIGHTS_FILE)
save_metadata(self, preset)
# save_to_preset(self, preset)
```
Remove commented-out code?
keras_nlp/models/task.py (Outdated)

```python
filter(lambda x: x.backbone_cls == preset_cls, subclasses)

task = None
try:
```
I think this could get a lot more readable if we made a `check_file_exists` or similarly named util. The try/except, and wrapping our `get_file` in preset utils with a `FileNotFoundError`, makes for a weird interface.
Added `check_file_exists`.
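For a local preset directory, the resulting helper might look roughly like the sketch below. This is a hypothetical simplification, not the actual KerasNLP implementation; the real util would also have to cover remote presets (e.g. Kaggle handles), where checking existence means attempting a download:

```python
import os

def check_file_exists(preset, path):
    """Return True if `path` exists inside a local preset directory.

    Replaces the try/except-around-get_file pattern: callers ask a
    yes/no question instead of catching FileNotFoundError.
    """
    return os.path.exists(os.path.join(preset, path))
```

Call sites then read as `if check_file_exists(preset, "task.json"): ...` instead of wrapping `get_file` in exception handling.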
keras_nlp/models/task.py (Outdated)

```python
    objects_to_skip=backbone_layer_ids,
)

def save_weights(self, filepath):
```
We should probably call this `save_task_weights`. It's different than vanilla `save_weights` for Keras; we should name and document it differently.
Done!
keras_nlp/tokenizers/tokenizer.py (Outdated)

```python
make_preset_dir(preset)
save_tokenizer_assets(self, preset)
save_serialized_object(self, preset, config_file=TOKENIZER_CONFIG_FILE)
# save_to_preset(self, preset, config_filename=TOKENIZER_CONFIG_FILE)
```
remove?
Done!
keras_nlp/models/task.py (Outdated)

```python
    Args:
        preset: The path to the local model preset directory.
    """
    check_keras_version()
```
Should we move `make_preset_dir(preset)` and `check_keras_version()` down into the other utilities, e.g. into `save_serialized_object`? That would make these top-level functions a little less cluttered and easier to read.
Good point! Moved these two to `save_serialized_object`.
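The refactor discussed above might be sketched roughly as follows. This is a hypothetical simplification: the real `save_serialized_object` serializes a Keras layer and also runs the Keras version check, neither of which is reproduced here:

```python
import json
import os

def save_serialized_object(config, preset, config_file="config.json"):
    """Write a config dict into a preset directory.

    The utility prepares the preset directory itself (what
    make_preset_dir used to do at each call site), so every
    save_to_preset implementation drops that boilerplate.
    """
    os.makedirs(preset, exist_ok=True)
    with open(os.path.join(preset, config_file), "w") as f:
        json.dump(config, f)
```

With this shape, `Task.save_to_preset` and friends reduce to a short sequence of utility calls with no setup code of their own.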
keras_nlp/models/task.py (Outdated)

```python
    return task
return cls(backbone=backbone, preprocessor=preprocessor, **kwargs)

def load_weights(self, filepath):
```
We should probably call this `load_task_weights`. It's different than vanilla `load_weights` for Keras; we should name and document it differently.
Done!
keras_nlp/models/task.py (Outdated)

```python
load_weights=load_weights,
config_overrides=kwargs,
config_file=TASK_CONFIG_FILE,
config_to_skip=["preprocessor", "backbone"],
```
To discuss, but I'm not sure we should do this.

For weights, we are saving the task bits separately, but that's really because weights are huge. We can't afford to save backbone weights and task weights separately in their entirety.

For configs, everything is small. We can duplicate, and can just effectively do `keras.saving.serialize_keras_object(task)` here. That means duplicated config between `tokenizer.json`, `backbone.json`, `preprocessor.json`, and `task.json`. But we don't care: it's lightweight, makes our code simpler, and most importantly keeps the assets we put in the format simple. Any user can call `keras.saving.deserialize_keras_object(task_json)` if they are so inclined (though that won't handle weight loading).
As we discussed offline, we'll allow config duplication, i.e. `task.json` includes the backbone config and `preprocessor.json` includes the tokenizer config.
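The agreed duplication might look roughly like the sketch below. The schema here is illustrative only (field names are made up, not the actual KerasNLP format); the point is that `task.json` embeds the full backbone and preprocessor configs inline rather than referencing sibling files, so the file deserializes on its own:

```python
import json

def build_task_config(task_cls_name, backbone_config, preprocessor_config):
    """Serialize a task config that embeds its sub-object configs.

    Duplicating the (small) backbone and preprocessor configs keeps the
    preset format self-contained: no cross-file patching is needed at
    load time.
    """
    return json.dumps({
        "class_name": task_cls_name,
        "config": {
            "backbone": backbone_config,
            "preprocessor": preprocessor_config,
        },
    })
```

The cost is a few duplicated kilobytes of JSON; the benefit is that loading never has to stitch configs together from multiple files.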
keras_nlp/models/task.py (Outdated)

```python
backbone = load_from_preset(
backbone_config = load_config(preset, CONFIG_FILE)
# TODO: this is not really an override! It's an addition! Should I rename this?
config_overrides = {"backbone": backbone_config}
```
Just making a comment below. But I think we might want to keep our JSON objects really simple, so we don't need to patch them like this.
Done!
keras_nlp/models/preprocessor.py (Outdated)

```python
cls = subclasses[0]
tokenizer = load_from_preset(

# For backward compatibility, if preset doesn't have `preprocessor.json`
```
Oops, I had a comment here I must have forgotten to hit save on. To discuss, but I think there are two main cases for task and preprocessor loading. Neither is backward compat.
```python
# Preprocessor load.
if exists("preprocessor.json") and is_class("preprocessor.json", cls):
    preprocessor = load_serialized_object(preset, "preprocessor.json", **kwargs)
    # load tokenizer assets
else:
    # Load from sub objects and create with default config.
    tokenizer = tokenizer_cls.from_preset(preset)
    preprocessor = cls(tokenizer=tokenizer)

# Task load.
if exists("task.json") and is_class("task.json", cls):
    task = load_serialized_object(preset, "task.json", **kwargs)
    # load weights
    # load tokenizer assets
else:
    # Load from sub objects and create with default config.
    backbone = backbone_cls.from_preset(preset)
    preprocessor = preprocess_cls.from_preset(preset)
    task = cls(backbone=backbone, preprocessor=preprocessor)
```
Thanks for the review, Matt!
Looks good! Few more comments.
keras_nlp/models/preprocessor.py (Outdated)

```python
    preset,
    PREPROCESSOR_CONFIG_FILE,
)
for asset in preprocessor.tokenizer.file_assets:
```
Can't we get rid of the function below and just do:

```python
for asset in preprocessor.tokenizer.file_assets:
    filename = get_file(preset, os.path.join(TOKENIZER_ASSET_DIR, asset))
    dirname = os.path.dirname(filename)
```

Seems simpler.
Done!
keras_nlp/models/task.py (Outdated)

```python
preset,
config_file="tokenizer.json",
load_weights=load_weights,
config_overrides=config_overrides,
```
Below this, we can have:

```python
preprocessor = cls.preprocessor_cls.from_preset(preset)
return cls(backbone=backbone, preprocessor=preprocessor, **kwargs)
```

and then we are done, I think. No need for the rest of this function. We can delegate to `preprocessor.from_preset`, which has the logic you have below.
Right! Preprocessor should have all the necessary validation logic to prevent repetition!
Done!
keras_nlp/utils/preset_utils.py (Outdated)

```python
):
    """Validate a preset is being loaded on the correct class."""
    config_path = get_file(preset, config_file)
    with open(config_path) as config_file:
        config = json.load(config_file)
    return keras.saving.get_registered_object(config["registered_name"])


def get_asset_dir(
```
I'm not sure this util needs to exist. See comment above.
Removed it!
Looks great!
I think we are missing a couple edge cases. But this looks really solid and readable overall!
keras_nlp/models/task.py (Outdated)

```python
)

save_serialized_object(self, preset_dir, config_file=TASK_CONFIG_FILE)
self.save_task_weights(get_file(preset_dir, TASK_WEIGHTS_FILE))
```
I think we need to think more about the case where our task has no new weights. For a lot of language modeling stuff, that will be the norm. I think the high-level behavior we want is:

```python
task = ...  # something where all weights are in the backbone
task.save_to_preset("dir")  # ok! no task.weights.h5 created
task.save_task_weights("task.weights.h5")  # probably good to error like we do now
```

Not sure how best to handle this in code. Just try/catch on this line? Add a `self.has_task_weights()` method? We need to handle this on the loading side too: skip `load_task_weights` if the file does not exist.
Added a `self.has_task_weights()` method to check that task weights exist before saving. I do this check again inside `save_task_weights()`, so `task.save_task_weights()` is a complete function on its own! For loading in `from_preset`, I added a check to only load weights if the file exists. Let me know what you think.
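The loading-side guard agreed on above might be sketched roughly like this. It is a hypothetical simplification (the helper name and the bare filename are illustrative, and the real code resolves the path through the preset utils):

```python
import os

def maybe_load_task_weights(task, preset):
    """Load task-specific weights only if the file exists.

    Tasks whose weights all live in the backbone never wrote
    task.weights.h5 at save time, so loading is skipped rather than
    raising. Returns whether a load happened.
    """
    path = os.path.join(preset, "task.weights.h5")
    if not os.path.exists(path):
        return False
    task.load_task_weights(path)
    return True
```

This keeps `from_preset` symmetric with saving: a preset without `task.weights.h5` round-trips cleanly instead of erroring.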
This looks great! Some final nits
keras_nlp/utils/preset_utils.py (Outdated)

```python
message = str(e)
if message.find("403 Client Error"):
    raise FileNotFoundError(
        f"`{path}` doesn't exist in preset directory `{preset}`.\n"
```
I think we usually don't need a trailing `\n` in error messages; it would format strangely. Only use it between lines we want to separate.
Right! Removed the `\n`!
keras_nlp/utils/preset_utils.py (Outdated)

```python
message = str(e)
if message.find("403 Client Error"):
    raise FileNotFoundError(
        f"`{path}` doesn't exist in preset directory `{preset}`.\n"
```
same here
Done!
keras_nlp/utils/preset_utils.py (Outdated)

```python
local_path = os.path.join(preset, path)
if not os.path.exists(local_path):
    raise FileNotFoundError(
        f"`{path}` doesn't exist in preset directory `{preset}`.\n"
```
and here
Done!
keras_nlp/utils/preset_utils.py (Outdated)

```python
    weights_filename="model.weights.h5",
):
    """Save a KerasNLP layer to a preset directory."""
def check_keras_version():
```
nit: `check_keras_3()` might improve readability.
Done!
To support saving and loading `Task`, the following changes have been made:

- Tasks are now saved as `task.json` and `task.weights.h5`.
- Added saving/loading support to `Preprocessor` (added `preprocessor.json`).
- Saving/loading utilities consolidated in `preset_utils.py`.

TODO: `preset_utils.py`.

Future plan: currently the backbone config and weights are called `config.json` and `model.weights.h5`. Our plan is to rename these to `backbone.json` and `backbone.weights.h5` in a followup PR.