[core] Add hub utils #39

Merged
merged 7 commits on Jan 30, 2023
Changes from 3 commits
104 changes: 99 additions & 5 deletions src/peft/tuners/lora.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import importlib
import math
import warnings
@@ -25,7 +25,10 @@
import torch.nn.functional as F
from transformers.pytorch_utils import Conv1D

from ..utils import PeftConfig, PeftType, transpose
from huggingface_hub import hf_hub_download

from transformers.utils import PushToHubMixin
from ..utils import PeftConfig, PeftType, transpose, WEIGHTS_NAME, CONFIG_NAME, get_peft_model_state_dict


def is_loralib_available():
@@ -36,6 +39,17 @@ def is_loralib_available():
import loralib as lora # noqa: F401
from loralib import mark_only_lora_as_trainable

MODEL_CARD_TEMPLATE = """---
license: apache-2.0
base_model: {base_model}
tags:
- peft
- lora
---
# Lora adapters for {model_name}

"""


@dataclass
class LoraConfig(PeftConfig):
@@ -72,7 +86,7 @@ def __post_init__(self):
self.peft_type = PeftType.LORA


class LoraModel(torch.nn.Module):
class LoraModel(PushToHubMixin, torch.nn.Module):
Member

is there any specific advantage to using transformers.PushToHubMixin vs huggingface_hub's? the latter means you only have to implement a single _save_pretrained() method, but perhaps you need more flexibility

Collaborator Author

I was not aware that the class you propose is simpler; I will have a look now and potentially replace it!

Collaborator Author

I tried it, and in both cases I needed to re-implement the from_pretrained method. Also, it seems that push_to_hub from PushToHubMixin first clones a repo with the weights before pushing to the Hub, so maybe I'll just stick with transformers.PushToHubMixin for now :D
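For context, a minimal, hypothetical sketch of the huggingface_hub alternative discussed in this thread (not what the PR ends up using): ModelHubMixin only asks subclasses to implement `_save_pretrained` (and, for loading, `_from_pretrained`). The class name, dummy layer shape, and weights file name below are illustrative assumptions.

import os

import torch
from huggingface_hub import ModelHubMixin


class LoraAdapterHubSketch(torch.nn.Module, ModelHubMixin):
    """Illustrative only: shows the single hook required by huggingface_hub's mixin."""

    def __init__(self):
        super().__init__()
        # Stand-in for a real LoRA weight; the shape is arbitrary here.
        self.lora_A = torch.nn.Linear(768, 8, bias=False)

    def _save_pretrained(self, save_directory):
        # ModelHubMixin.save_pretrained creates `save_directory`, then calls this
        # hook; only the adapter weights are serialized into it.
        torch.save(self.state_dict(), os.path.join(save_directory, "adapter_model.bin"))

`save_pretrained()` and `push_to_hub()` are then inherited from the mixin; loading would still need a custom `_from_pretrained`, matching the author's point that `from_pretrained` had to be re-implemented either way.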

"""
Creates Low Rank Adapter (Lora) model from a pretrained transformers model.

@@ -85,8 +99,9 @@ class LoraModel(torch.nn.Module):

Example::

>>> from transformers import AutoModelForSeq2SeqLM, LoraConfig >>> from peft import LoraModel, LoraConfig >>>
config = LoraConfig(
>>> from transformers import AutoModelForSeq2SeqLM, LoraConfig
>>> from peft import LoraModel, LoraConfig
>>> config = LoraConfig(
peft_type="LORA", task_type="SEQ_2_SEQ_LM", r=8, lora_alpha=32, target_modules=["q", "v"],
lora_dropout=0.01, )
>>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") >>> lora_model = LoraModel(config, model)
@@ -147,6 +162,83 @@ def _replace_module(self, parent_module, child_name, new_module, old_module):
if old_module.bias is not None:
new_module.bias = old_module.bias

def save_pretrained(self, save_directory, **kwargs):
r"""
This function saves the adapter model and the adapter configuration files to a directory, so that it
can be re-loaded using the `LoraModel.from_pretrained` class method, and also used by the
`LoraModel.push_to_hub` method.

Args:
save_directory (`str`):
Directory where the adapter model and configuration files will be saved (will be created if it does not
exist).
**kwargs:
Additional keyword arguments passed along to the `push_to_hub` method.
"""
if os.path.isfile(save_directory):
raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
os.makedirs(save_directory, exist_ok=True)

# save the config
self.peft_config.save_pretrained(save_directory)

for param in self.parameters():
param.requires_grad = False # freeze the model

# save only the trainable weights
output_state_dict = get_peft_model_state_dict(self)
torch.save(output_state_dict, os.path.join(save_directory, WEIGHTS_NAME))

# save model card
if 'name_or_path' in self.model.__dict__:
model_name = self.model.__dict__['name_or_path']
else:
model_name = None
model_card_content = MODEL_CARD_TEMPLATE.format(model_name=model_name, base_model=model_name)
with open(os.path.join(save_directory, "README.md"), "w", encoding="utf-8") as f:
f.write(model_card_content)

@classmethod
def from_pretrained(cls, model, lora_id, **kwargs):
r"""
Instantiate a `LoraModel` from a pretrained Lora configuration and weights.

Args:
model (`transformers.PreTrainedModel`):
The model to be adapted. The model should be initialized with the `from_pretrained` method
from the `transformers` library.
lora_id (`str`):
The name of the Lora configuration to use. Can be either:
- A string, the `model id` of a Lora configuration hosted inside a model repo on
huggingface Hub
- A path to a directory containing a Lora configuration file saved using the
`save_pretrained` method, e.g., ``./my_lora_config_directory/``.
"""
# load the config
config = LoraConfig.from_pretrained(lora_id)

model = cls(config, model)

# load weights if any
if os.path.exists(os.path.join(lora_id, WEIGHTS_NAME)):
filename = os.path.join(lora_id, WEIGHTS_NAME)
else:
try:
filename = hf_hub_download(lora_id, WEIGHTS_NAME)
except: # noqa
raise ValueError(
f"Can't find weights for {lora_id} in {lora_id} or in the Hugging Face Hub. "
f"Please check that the file {WEIGHTS_NAME} is present at {lora_id}."
)


adapters_weights = torch.load(filename)
# load the weights into the model
model.load_state_dict(adapters_weights, strict=False)

return model


def forward(self, *args, **kwargs):
return self.model(*args, **kwargs)

@@ -168,6 +260,8 @@ def get_peft_config_as_dict(self, inference: bool = False):
return config




# Below code is based on https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
# and modified to work with PyTorch FSDP

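For reference, a short usage sketch of the hub API added to LoraModel above, as it stands at this commit; the local directory name is a placeholder and `t5-base` simply mirrors the docstring example:

from transformers import AutoModelForSeq2SeqLM
from peft import LoraConfig, LoraModel

config = LoraConfig(
    peft_type="LORA", task_type="SEQ_2_SEQ_LM", r=8, lora_alpha=32,
    target_modules=["q", "v"], lora_dropout=0.01,
)
base_model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
lora_model = LoraModel(config, base_model)

# Writes adapter_model.bin, adapter_config.json and a generated README.md.
lora_model.save_pretrained("./my-lora-adapter")

# Reloading: instantiate the base model again, then attach the saved adapter.
base_model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
lora_model = LoraModel.from_pretrained(base_model, "./my-lora-adapter")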
1 change: 1 addition & 0 deletions src/peft/utils/__init__.py
@@ -20,3 +20,4 @@
from .config import PeftConfig, PeftType, PromptLearningConfig, TaskType
from .other import _set_trainable, bloom_model_postprocess_past_key_value, shift_tokens_right, transpose
from .save_and_load import get_peft_model_state_dict, peft_model_load_and_dispatch, set_peft_model_state_dict
from .adapters_utils import WEIGHTS_NAME, CONFIG_NAME
18 changes: 18 additions & 0 deletions src/peft/utils/adapters_utils.py
@@ -0,0 +1,18 @@
# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
WEIGHTS_NAME = "adapter_model.bin"
CONFIG_NAME = "adapter_config.json"

# TODO: add automapping and superclass here?
99 changes: 94 additions & 5 deletions src/peft/utils/config.py
@@ -12,11 +12,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import json
import enum
from dataclasses import dataclass, field
from dataclasses import dataclass, field, asdict
from typing import Optional, Union

from transformers.utils import PushToHubMixin
from huggingface_hub import hf_hub_download

from .adapters_utils import CONFIG_NAME

class PeftType(str, enum.Enum):
PROMPT_TUNING = "PROMPT_TUNING"
@@ -33,7 +38,93 @@ class TaskType(str, enum.Enum):


@dataclass
class PeftConfig:
class PeftConfigMixin(PushToHubMixin):
r"""
This is the base configuration class for PEFT adapter models. It contains all the methods that
are common to all PEFT adapter models.
This class inherits from `transformers.utils.PushToHubMixin`, which contains the methods to push your model to the Hub.
The `save_pretrained` method saves the configuration of your adapter model in a directory, and
the `from_pretrained` method loads it back.

Args:
peft_type (Union[[`~peft.utils.config.PeftType`], `str`]): The type of Peft method to use.
"""
peft_type: Optional[PeftType] = field(default=None, metadata={"help": "The type of PEFT model."})

@property
def __dict__(self):
return asdict(self)

def to_dict(self):
return self.__dict__

def save_pretrained(self, save_directory, **kwargs):
r"""
This method saves the configuration of your adapter model in a directory.

Args:
save_directory (`str`):
The directory where the configuration will be saved.
**kwargs:
Additional keyword arguments passed along to the `transformers.utils.PushToHubMixin.push_to_hub` method.
"""
if os.path.isfile(save_directory):
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")

os.makedirs(save_directory, exist_ok=True)

output_dict = self.__dict__
output_path = os.path.join(save_directory, CONFIG_NAME)

# save it
with open(output_path, "w") as writer:
writer.write(json.dumps(output_dict, indent=2, sort_keys=True))

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
r"""
This method loads the configuration of your adapter model from a directory.

Args:
pretrained_model_name_or_path (`str`):
The directory or the hub-id where the configuration is saved.
**kwargs:
Additional keyword arguments passed along to the child class initialization.
"""
if os.path.isfile(os.path.join(pretrained_model_name_or_path, CONFIG_NAME)):
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
else:
try:
config_file = hf_hub_download(pretrained_model_name_or_path, CONFIG_NAME)
except:
raise ValueError(f"Can't find config.json at '{pretrained_model_name_or_path}'")

loaded_attributes = cls.from_json_file(config_file)

config = cls(**kwargs)

for key, value in loaded_attributes.items():
if hasattr(config, key):
setattr(config, key, value)

return config

@classmethod
def from_json_file(cls, path_json_file, **kwargs):
r"""
Loads a configuration file from a json file.

Args:
path_json_file (`str`):
The path to the json file.
"""
with open(path_json_file, 'r') as file:
json_object = json.load(file)

return json_object


class PeftConfig(PeftConfigMixin):
"""
This is the base configuration class to store the configuration of a :class:`~peft.PeftModel`.

@@ -42,13 +133,11 @@ class PeftConfig:
task_type (Union[[`~peft.utils.config.TaskType`], `str`]): The type of task to perform.
inference_mode (`bool`, defaults to `False`): Whether to use the Peft model in inference mode.
"""

peft_type: Union[str, PeftType] = field(default=None, metadata={"help": "Peft type"})
task_type: Union[str, TaskType] = field(default=None, metadata={"help": "Task type"})
inference_mode: bool = field(default=False, metadata={"help": "Whether to use inference mode"})


@dataclass
Contributor

This is required, otherwise a similar error to the one in the previous comment occurs.

Collaborator Author

Thanks! In any case, I added tests to check that the script you shared above will not fail.

class PromptLearningConfig(PeftConfig):
"""
This is the base configuration class to store the configuration of a Union[[`~peft.PrefixTuning`],
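A minimal round-trip sketch for the config serialization added to config.py above; the directory name is a placeholder and the hyperparameters are arbitrary:

from peft import LoraConfig

config = LoraConfig(
    task_type="SEQ_2_SEQ_LM", r=8, lora_alpha=32,
    target_modules=["q", "v"], lora_dropout=0.01,
)
config.save_pretrained("./my-lora-adapter")  # writes adapter_config.json

# from_pretrained also accepts a Hub model id; a local path is used here.
reloaded = LoraConfig.from_pretrained("./my-lora-adapter")
assert reloaded.r == 8 and reloaded.peft_type == "LORA"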