This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Support loading PEFT (LoRA) models
idoru committed Jun 14, 2023
1 parent e4dabb0 commit ef138a0
Showing 4 changed files with 15 additions and 1 deletion.
1 change: 1 addition & 0 deletions basaran/__init__.py
@@ -18,6 +18,7 @@ def is_true(value):
 PORT = int(os.getenv("PORT", "80"))
 
 # Model-related arguments:
+MODEL_PEFT = is_true(os.getenv("MODEL_PEFT", ""))
 MODEL_REVISION = os.getenv("MODEL_REVISION", "")
 MODEL_CACHE_DIR = os.getenv("MODEL_CACHE_DIR", "models")
 MODEL_LOAD_IN_8BIT = is_true(os.getenv("MODEL_LOAD_IN_8BIT", ""))
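For reference, the new flag follows the file's existing pattern: boolean options are environment variables parsed through is_true. A minimal sketch of that pattern (the helper body below is illustrative, since this diff only shows the helper's signature):

    import os

    def is_true(value):
        # Illustrative stand-in for Basaran's helper, whose body is not
        # shown in this diff: treat common truthy strings as True.
        return value.strip().lower() in ("1", "true", "yes", "on")

    # Unset or empty means False, so PEFT loading stays off by default.
    MODEL_PEFT = is_true(os.getenv("MODEL_PEFT", ""))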
2 changes: 2 additions & 0 deletions basaran/__main__.py
@@ -23,6 +23,7 @@
 from . import MODEL_LOAD_IN_4BIT
 from . import MODEL_4BIT_QUANT_TYPE
 from . import MODEL_4BIT_DOUBLE_QUANT
+from . import MODEL_PEFT
 from . import MODEL_LOCAL_FILES_ONLY
 from . import MODEL_TRUST_REMOTE_CODE
 from . import MODEL_HALF_PRECISION
@@ -44,6 +45,7 @@
     name_or_path=MODEL,
     revision=MODEL_REVISION,
     cache_dir=MODEL_CACHE_DIR,
+    is_peft=MODEL_PEFT,
     load_in_8bit=MODEL_LOAD_IN_8BIT,
     load_in_4bit=MODEL_LOAD_IN_4BIT,
     quant_type=MODEL_4BIT_QUANT_TYPE,
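Read together with model.py below, this wiring also means the loader can be exercised directly. A minimal sketch with a hypothetical adapter repository name, using only keyword arguments that appear in this diff:

    from basaran.model import load_model

    # Hypothetical PEFT adapter repo; is_peft tells the loader to resolve
    # the base model from the adapter's config before loading weights.
    model = load_model(
        name_or_path="some-user/llama-7b-lora",
        is_peft=True,
        load_in_8bit=True,
    )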
12 changes: 11 additions & 1 deletion basaran/model.py
@@ -14,6 +14,10 @@
     TopPLogitsWarper,
     BitsAndBytesConfig
 )
+from peft import (
+    PeftConfig,
+    PeftModel
+)
 
 from .choice import map_choice
 from .tokenizer import StreamTokenizer
@@ -310,6 +314,7 @@ def load_model(
     name_or_path,
     revision=None,
     cache_dir=None,
+    is_peft=False,
     load_in_8bit=False,
     load_in_4bit=False,
     quant_type="fp4",
@@ -327,7 +332,6 @@
         kwargs["revision"] = revision
     if cache_dir:
         kwargs["cache_dir"] = cache_dir
-    tokenizer = AutoTokenizer.from_pretrained(name_or_path, **kwargs)
 
     # Set device mapping and quantization options if CUDA is available.
     if torch.cuda.is_available():
@@ -354,6 +358,12 @@
     if half_precision or load_in_8bit or load_in_4bit:
         kwargs["torch_dtype"] = torch.float16
 
+    if is_peft:
+        peft_config = PeftConfig.from_pretrained(name_or_path)
+        name_or_path = peft_config.base_model_name_or_path
+
+    tokenizer = AutoTokenizer.from_pretrained(name_or_path, **kwargs)
+
     # Support both decoder-only and encoder-decoder models.
     try:
         model = AutoModelForCausalLM.from_pretrained(name_or_path, **kwargs)
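Two things happen here: the tokenizer load moves below the PEFT resolution (this commit's single deletion), and when is_peft is set, name_or_path is rewritten from the adapter repository to its base model, since adapter repos typically ship only adapter weights and a config. The PeftModel import added above suggests the adapter itself is attached further down, outside the displayed hunks. A minimal standalone sketch of that pattern with the peft library, using a hypothetical adapter name:

    from peft import PeftConfig, PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer

    adapter = "some-user/llama-7b-lora"  # hypothetical adapter repo

    # Resolve the adapter's base model, mirroring the hunk above.
    peft_config = PeftConfig.from_pretrained(adapter)
    base = peft_config.base_model_name_or_path

    # Tokenizer and base weights come from the base model.
    tokenizer = AutoTokenizer.from_pretrained(base)
    model = AutoModelForCausalLM.from_pretrained(base)

    # Attach the LoRA adapter weights on top of the base model.
    model = PeftModel.from_pretrained(model, adapter)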
1 change: 1 addition & 0 deletions requirements.txt
@@ -12,3 +12,4 @@ safetensors~=0.3.1
 torch>=1.12.1
 transformers[sentencepiece]~=4.30.1
 waitress~=2.1.2
+peft~=0.3.0
