-
Notifications
You must be signed in to change notification settings - Fork 1.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[core] Add hub utils
#39
Changes from 3 commits
2896cf0
2cc7f2c
634f369
ad69958
16182ea
22295c4
6c9534e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# coding=utf-8 | ||
# Copyright 2023-present the HuggingFace Inc. team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# Default filename for serialized adapter weights saved by `save_pretrained`.
WEIGHTS_NAME = "adapter_model.bin"
# Default filename for the adapter configuration JSON written/read by
# `PeftConfigMixin.save_pretrained` / `from_pretrained`.
CONFIG_NAME = "adapter_config.json"
|
||
# TODO: add automapping and superclass here? |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,11 +12,16 @@ | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import os | ||
import json | ||
import enum | ||
from dataclasses import dataclass, field | ||
from dataclasses import dataclass, field, asdict | ||
from typing import Optional, Union | ||
|
||
from transformers.utils import PushToHubMixin | ||
from huggingface_hub import hf_hub_download | ||
|
||
from .adapters_utils import CONFIG_NAME | ||
|
||
class PeftType(str, enum.Enum): | ||
PROMPT_TUNING = "PROMPT_TUNING" | ||
|
@@ -33,7 +38,93 @@ class TaskType(str, enum.Enum): | |
|
||
|
||
@dataclass
class PeftConfigMixin(PushToHubMixin):
    r"""
    This is the base configuration class for PEFT adapter models. It contains all the methods that
    are common to all PEFT adapter models.

    This class inherits from `transformers.utils.PushToHubMixin` which contains the methods to
    push your model to the Hub. The method `save_pretrained` will save the configuration of your
    adapter model in a directory. The method `from_pretrained` will load the configuration of your
    adapter model from a directory.

    Args:
        peft_type (Union[[`~peft.utils.config.PeftType`], `str`]): The type of Peft method to use.
    """
    peft_type: Optional[PeftType] = field(default=None, metadata={"help": "The type of PEFT model."})

    @property
    def __dict__(self):
        # Dataclass fields serialized recursively; overriding __dict__ keeps
        # json.dumps(self.__dict__) in sync with the declared fields.
        return asdict(self)

    def to_dict(self):
        """Return the configuration as a plain dictionary of dataclass fields."""
        return self.__dict__

    def save_pretrained(self, save_directory, **kwargs):
        r"""
        This method saves the configuration of your adapter model in a directory.

        Args:
            save_directory (`str`):
                The directory where the configuration will be saved.
            **kwargs:
                Additional keyword arguments passed along to the
                `transformers.utils.PushToHubMixin.push_to_hub` method.
        """
        if os.path.isfile(save_directory):
            raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")

        os.makedirs(save_directory, exist_ok=True)

        output_dict = self.__dict__
        output_path = os.path.join(save_directory, CONFIG_NAME)

        # save it
        with open(output_path, "w", encoding="utf-8") as writer:
            writer.write(json.dumps(output_dict, indent=2, sort_keys=True))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        r"""
        This method loads the configuration of your adapter model from a directory.

        Args:
            pretrained_model_name_or_path (`str`):
                The directory or the hub-id where the configuration is saved.
            **kwargs:
                Additional keyword arguments passed along to the child class initialization.

        Raises:
            ValueError: If the configuration file cannot be found locally or on the Hub.
        """
        if os.path.isfile(os.path.join(pretrained_model_name_or_path, CONFIG_NAME)):
            config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
        else:
            # Fall back to downloading from the Hub. Catch only Exception (not a bare
            # except, which would also swallow KeyboardInterrupt/SystemExit) and chain
            # the original error so the root cause stays visible.
            try:
                config_file = hf_hub_download(pretrained_model_name_or_path, CONFIG_NAME)
            except Exception as exc:
                # Report the actual filename looked up (CONFIG_NAME, i.e.
                # "adapter_config.json"), not a hard-coded "config.json".
                raise ValueError(
                    f"Can't find {CONFIG_NAME} at '{pretrained_model_name_or_path}'"
                ) from exc

        loaded_attributes = cls.from_json_file(config_file)

        config = cls(**kwargs)

        # Only restore attributes the config class actually declares; unknown
        # keys in the JSON file are silently ignored.
        for key, value in loaded_attributes.items():
            if hasattr(config, key):
                setattr(config, key, value)

        return config

    @classmethod
    def from_json_file(cls, path_json_file, **kwargs):
        r"""
        Loads a configuration file from a json file.

        Args:
            path_json_file (`str`):
                The path to the json file.

        Returns:
            `dict`: The parsed JSON content of the file.
        """
        with open(path_json_file, "r", encoding="utf-8") as file:
            json_object = json.load(file)

        return json_object
|
||
|
||
younesbelkada marked this conversation as resolved.
Show resolved
Hide resolved
|
||
class PeftConfig(PeftConfigMixin): | ||
""" | ||
This is the base configuration class to store the configuration of a :class:`~peft.PeftModel`. | ||
|
||
|
@@ -42,13 +133,11 @@ class PeftConfig: | |
task_type (Union[[`~peft.utils.config.TaskType`], `str`]): The type of task to perform. | ||
inference_mode (`bool`, defaults to `False`): Whether to use the Peft model in inference mode. | ||
""" | ||
|
||
peft_type: Union[str, PeftType] = field(default=None, metadata={"help": "Peft type"}) | ||
task_type: Union[str, TaskType] = field(default=None, metadata={"help": "Task type"}) | ||
inference_mode: bool = field(default=False, metadata={"help": "Whether to use inference mode"}) | ||
|
||
|
||
@dataclass | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is required else similar error as seen in the previous comment There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks! In any case I added tests to test that the script you shared above will not fail |
||
class PromptLearningConfig(PeftConfig): | ||
""" | ||
This is the base configuration class to store the configuration of a Union[[`~peft.PrefixTuning`], | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is there any specific advantage to using `transformers.PushToHubMixin` vs `huggingface_hub`'s? The latter means you only have to implement a single `_save_pretrained()` method, but perhaps you need more flexibility.

There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.

I was not aware that the class you propose is simpler, I will have a look now and potentially replace it!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tried it and in both cases I needed to re-implement the `from_pretrained` method. Also, it seems that `push_to_hub` from `PushToHubMixin` first clones a repo with the weights before pushing to the Hub, so maybe I'll just stick to `transformers.PushToHubMixin` for now :D