[core] Add hub utils #39

Merged on Jan 30, 2023 (7 commits).
Changes from all commits
2 changes: 2 additions & 0 deletions src/peft/mapping.py
@@ -138,4 +138,6 @@ def get_peft_model(model, peft_config):
else:
peft_config = _prepare_lora_config(peft_config, model_config)

peft_config.base_model_name_or_path = model.__dict__.get("name_or_path", None)

return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type](model, peft_config)
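
Not part of the diff: a minimal usage sketch of what this mapping.py change does, assuming a base model loaded with `from_pretrained` (so `name_or_path` may be set on the instance) and a model type that has a default LoRA target-module mapping in this version.

# Hedged sketch (not from this PR): get_peft_model now records where the base
# model came from, so save_pretrained can later persist it in the adapter config.
from transformers import AutoModelForCausalLM

from peft import LoraConfig, get_peft_model

base_model = AutoModelForCausalLM.from_pretrained("gpt2")  # example checkpoint
peft_config = LoraConfig(task_type="CAUSAL_LM")
peft_model = get_peft_model(base_model, peft_config)

# Filled from model.__dict__.get("name_or_path", None); stays None if the
# attribute is not present on the model instance.
print(peft_model.peft_config.base_model_name_or_path)
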
88 changes: 84 additions & 4 deletions src/peft/peft_model.py
@@ -14,18 +14,31 @@
# limitations under the License.

import inspect
import os
import warnings

import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers import PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput, TokenClassifierOutput
from transformers.utils import PushToHubMixin

-from .tuners import LoraModel, PrefixEncoder, PromptEmbedding, PromptEncoder
-from .utils import PeftConfig, PeftType, TaskType, _set_trainable, shift_tokens_right

+from huggingface_hub import hf_hub_download

-class PeftModel(torch.nn.Module):
+from .tuners import LoraModel, PrefixEncoder, PromptEmbedding, PromptEncoder
+from .utils import (
+WEIGHTS_NAME,
+PeftConfig,
+PeftType,
+TaskType,
+_set_trainable,
+get_peft_model_state_dict,
+set_peft_model_state_dict,
+shift_tokens_right,
+)


+class PeftModel(PushToHubMixin, torch.nn.Module):
"""
Parameter-Efficient Fine-Tuning Model. Base model encompassing various Peft methods.

@@ -61,6 +74,73 @@ def __init__(self, model, peft_config: PeftConfig):
self.base_model = LoraModel(peft_config, model)
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def save_pretrained(self, save_directory, **kwargs):
r"""
Args:
This function saves the adapter model and the adapter configuration files to a directory, so that it can be
re-loaded using the `LoraModel.from_pretrained` class method, and also used by the `LoraModel.push_to_hub`
method.
save_directory (`str`):
Directory where the adapter model and configuration files will be saved (will be created if it does not
exist).
**kwargs:
Additional keyword arguments passed along to the `push_to_hub` method.
"""
if os.path.isfile(save_directory):
raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
os.makedirs(save_directory, exist_ok=True)

# save the config
if self.peft_config.base_model_name_or_path is None:
self.peft_config.base_model_name_or_path = self.base_model.__dict__.get("name_or_path", None)
self.peft_config.inference_mode = True
self.peft_config.save_pretrained(save_directory)

for param in self.parameters():
param.requires_grad = False # freeze the model

# save only the trainable weights
output_state_dict = get_peft_model_state_dict(self, kwargs.get("state_dict", None))
torch.save(output_state_dict, os.path.join(save_directory, WEIGHTS_NAME))
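
A hedged usage sketch (not part of the diff) of the save path, continuing the example above; the two file names come from WEIGHTS_NAME and CONFIG_NAME defined in adapters_utils.py further down.

# Minimal sketch, assuming the `peft_model` from the previous example.
# save_pretrained writes only the adapter artifacts, not the base model:
#   adapter_config.json  - written by peft_config.save_pretrained
#   adapter_model.bin    - only the trainable adapter weights
# It also sets inference_mode=True on the config and freezes all parameters.
import os

peft_model.save_pretrained("./my-lora-adapter")
print(sorted(os.listdir("./my-lora-adapter")))  # ['adapter_config.json', 'adapter_model.bin']

# Because PeftModel now inherits from PushToHubMixin, the same artifacts can
# be pushed to the Hub (hypothetical repo id):
# peft_model.push_to_hub("my-user/my-lora-adapter")
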

@classmethod
def from_pretrained(cls, model, model_id, **kwargs):
r"""
Args:
Instantiate a `LoraModel` from a pretrained Lora configuration and weights.
model (`transformers.PreTrainedModel`):
The model to be adapted. The model should be initialized with the `from_pretrained` method. from
`transformers` library.
model_id (`str`):
The name of the Lora configuration to use. Can be either:
- A string, the `model id` of a Lora configuration hosted inside a model repo on
huggingface Hub
- A path to a directory containing a Lora configuration file saved using the
`save_pretrained` method, e.g., ``./my_lora_config_directory/``.
"""
from .mapping import MODEL_TYPE_TO_PEFT_MODEL_MAPPING, PEFT_TYPE_TO_CONFIG_MAPPING

# load the config
config = PEFT_TYPE_TO_CONFIG_MAPPING[PeftConfig.from_pretrained(model_id).peft_type].from_pretrained(model_id)

model = MODEL_TYPE_TO_PEFT_MODEL_MAPPING[config.task_type](model, config)

# load weights if any
if os.path.exists(os.path.join(model_id, WEIGHTS_NAME)):
filename = os.path.join(model_id, WEIGHTS_NAME)
else:
try:
filename = hf_hub_download(model_id, WEIGHTS_NAME)
except Exception:
raise ValueError(
f"Can't find weights for {model_id} locally or on the Hugging Face Hub. "
f"Please check that the file {WEIGHTS_NAME} is present at {model_id}."
)

adapters_weights = torch.load(filename)
# load the weights into the model
return set_peft_model_state_dict(model, adapters_weights)
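
The matching load path, again as a hedged sketch rather than part of the diff: the adapter is layered on top of a freshly instantiated base model, and the weights are resolved from the local directory first, then via hf_hub_download.

# Minimal sketch, assuming the adapter saved above and the same base architecture.
from transformers import AutoModelForCausalLM

from peft import PeftModel

base_model = AutoModelForCausalLM.from_pretrained("gpt2")
peft_model = PeftModel.from_pretrained(base_model, "./my-lora-adapter")

# A Hub repo id works the same way, e.g. (hypothetical id):
# peft_model = PeftModel.from_pretrained(base_model, "my-user/my-lora-adapter")

peft_model.eval()  # the saved adapter_config.json carries inference_mode=True
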

def _setup_prompt_encoder(self):
num_transformer_submodules = 0
transformer_backbone = None
1 change: 0 additions & 1 deletion src/peft/tuners/lora.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import math
import warnings
1 change: 1 addition & 0 deletions src/peft/utils/__init__.py
@@ -17,6 +17,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .adapters_utils import CONFIG_NAME, WEIGHTS_NAME
from .config import PeftConfig, PeftType, PromptLearningConfig, TaskType
from .other import _set_trainable, bloom_model_postprocess_past_key_value, shift_tokens_right, transpose
from .save_and_load import get_peft_model_state_dict, peft_model_load_and_dispatch, set_peft_model_state_dict
18 changes: 18 additions & 0 deletions src/peft/utils/adapters_utils.py
@@ -0,0 +1,18 @@
# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
WEIGHTS_NAME = "adapter_model.bin"
CONFIG_NAME = "adapter_config.json"

# TODO: add automapping and superclass here?
101 changes: 98 additions & 3 deletions src/peft/utils/config.py
@@ -12,11 +12,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import enum
from dataclasses import dataclass, field
import json
import os
from dataclasses import asdict, dataclass, field
from typing import Optional, Union

from transformers.utils import PushToHubMixin

from huggingface_hub import hf_hub_download

from .adapters_utils import CONFIG_NAME


class PeftType(str, enum.Enum):
PROMPT_TUNING = "PROMPT_TUNING"
@@ -33,7 +40,94 @@ class TaskType(str, enum.Enum):


@dataclass
class PeftConfig:
class PeftConfigMixin(PushToHubMixin):
r"""
This is the base configuration class for PEFT adapter models. It contains all the methods that are common to all
PEFT adapter models. This class inherits from `transformers.utils.PushToHubMixin` which contains the methods to
push your model to the Hub. The method `save_pretrained` will save the configuration of your adapter model in a
directory. The method `from_pretrained` will load the configuration of your adapter model from a directory.

Args:
peft_type (Union[[`~peft.utils.config.PeftType`], `str`]): The type of Peft method to use.
"""
peft_type: Optional[PeftType] = field(default=None, metadata={"help": "The type of PEFT model."})

@property
def __dict__(self):
return asdict(self)

def to_dict(self):
return self.__dict__

def save_pretrained(self, save_directory, **kwargs):
r"""
This method saves the configuration of your adapter model in a directory.

Args:
save_directory (`str`):
The directory where the configuration will be saved.
**kwargs:
Additional keyword arguments passed along to the `transformers.utils.PushToHubMixin.push_to_hub`
method.
"""
if os.path.isfile(save_directory):
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")

os.makedirs(save_directory, exist_ok=True)

output_dict = self.__dict__
output_path = os.path.join(save_directory, CONFIG_NAME)

# save it
with open(output_path, "w") as writer:
writer.write(json.dumps(output_dict, indent=2, sort_keys=True))
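
A hedged sketch (not in the diff) of what save_pretrained writes for the base PeftConfig defined below; subclasses such as LoraConfig add their own fields to the serialized dict.

# Minimal sketch: saving a default PeftConfig produces an adapter_config.json
# with the dataclass fields serialized by asdict(), roughly:
# {
#   "base_model_name_or_path": null,
#   "inference_mode": false,
#   "peft_type": null,
#   "task_type": null
# }
from peft.utils import PeftConfig  # exported by src/peft/utils/__init__.py above

config = PeftConfig()
config.save_pretrained("./my-adapter-config")
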

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
r"""
This method loads the configuration of your adapter model from a directory.

Args:
pretrained_model_name_or_path (`str`):
The directory or the Hugging Face Hub repo id where the configuration is saved.
**kwargs:
Additional keyword arguments passed along to the child class initialization.
"""
if os.path.isfile(os.path.join(pretrained_model_name_or_path, CONFIG_NAME)):
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
else:
try:
config_file = hf_hub_download(pretrained_model_name_or_path, CONFIG_NAME)
except Exception:
raise ValueError(f"Can't find '{CONFIG_NAME}' at '{pretrained_model_name_or_path}'")

loaded_attributes = cls.from_json_file(config_file)

config = cls(**kwargs)

for key, value in loaded_attributes.items():
if hasattr(config, key):
setattr(config, key, value)

return config
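
One behavioral detail worth a hedged sketch (not in the diff): from_pretrained builds the config from **kwargs first and then copies every attribute found in the JSON over it, so values stored in the file win over values passed at call time.

# Minimal sketch, assuming the config saved in the previous example.
from peft.utils import PeftConfig

# inference_mode=True is passed here, but since adapter_config.json also stores
# inference_mode, the stored value overwrites the keyword argument.
config = PeftConfig.from_pretrained("./my-adapter-config", inference_mode=True)
print(config.inference_mode)  # value from the saved file (False here), not True
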

@classmethod
def from_json_file(cls, path_json_file, **kwargs):
r"""
Loads a configuration file from a json file.

Args:
path_json_file (`str`):
The path to the json file.
"""
with open(path_json_file, "r") as file:
json_object = json.load(file)

return json_object


@dataclass
class PeftConfig(PeftConfigMixin):
"""
This is the base configuration class to store the configuration of a :class:`~peft.PeftModel`.

@@ -43,6 +137,7 @@ class PeftConfig:
inference_mode (`bool`, defaults to `False`): Whether to use the Peft model in inference mode.
"""

base_model_name_or_path: str = field(default=None, metadata={"help": "The name of the base model to use."})
peft_type: Union[str, PeftType] = field(default=None, metadata={"help": "Peft type"})
task_type: Union[str, TaskType] = field(default=None, metadata={"help": "Task type"})
inference_mode: bool = field(default=False, metadata={"help": "Whether to use inference mode"})
98 changes: 98 additions & 0 deletions tests/test_config.py
@@ -0,0 +1,98 @@
# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import tempfile
import os

from peft import LoraConfig, PromptEncoderConfig, PrefixTuningConfig, PromptTuningConfig

class PeftConfigTestMixin:
all_config_classes = (
LoraConfig,
PromptEncoderConfig,
PrefixTuningConfig,
PromptTuningConfig,
)


class PeftConfigTester(unittest.TestCase, PeftConfigTestMixin):
def test_methods(self):
r"""
Test if all configs have the expected methods. Here we test
- to_dict
- save_pretrained
- from_pretrained
- from_json_file
"""
# test if all configs have the expected methods
for config_class in self.all_config_classes:
config = config_class()
self.assertTrue(hasattr(config, "to_dict"))
self.assertTrue(hasattr(config, "save_pretrained"))
self.assertTrue(hasattr(config, "from_pretrained"))
self.assertTrue(hasattr(config, "from_json_file"))

def test_task_type(self):
for config_class in self.all_config_classes:
# assert this will not fail
_ = config_class(task_type="test")


def test_save_pretrained(self):
r"""
Test if the config is correctly saved and loaded using
- save_pretrained
"""
for config_class in self.all_config_classes:
config = config_class()
with tempfile.TemporaryDirectory() as tmp_dirname:
config.save_pretrained(tmp_dirname)

config_from_pretrained = config_class.from_pretrained(tmp_dirname)
self.assertEqual(config.to_dict(), config_from_pretrained.to_dict())

def test_from_json_file(self):
for config_class in self.all_config_classes:
config = config_class()
with tempfile.TemporaryDirectory() as tmp_dirname:
config.save_pretrained(tmp_dirname)

config_from_json = config_class.from_json_file(os.path.join(tmp_dirname, "adapter_config.json"))
self.assertEqual(config.to_dict(), config_from_json)


def test_to_dict(self):
r"""
Test if the config can be correctly converted to a dict using:
- to_dict
- __dict__
"""
for config_class in self.all_config_classes:
config = config_class()
self.assertEqual(config.to_dict(), config.__dict__)
self.assertTrue(isinstance(config.to_dict(), dict))


def test_set_attributes(self):
# manually set attributes and check if they are correctly written
for config_class in self.all_config_classes:
config = config_class(peft_type="test")

# save pretrained
with tempfile.TemporaryDirectory() as tmp_dirname:
config.save_pretrained(tmp_dirname)

config_from_pretrained = config_class.from_pretrained(tmp_dirname)
self.assertEqual(config.to_dict(), config_from_pretrained.to_dict())