Skip to content

Commit

Permalink
feat: import model for transformers framework (#4247)
Browse files Browse the repository at this point in the history
* add import_model to transformers

---------

Co-authored-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
Co-authored-by: Frost Ming <mianghong@gmail.com>
  • Loading branch information
3 people committed Nov 7, 2023
1 parent 1fed50c commit 3b8a9d2
Show file tree
Hide file tree
Showing 4 changed files with 374 additions and 38 deletions.
43 changes: 9 additions & 34 deletions src/bentoml/_internal/frameworks/diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@

import logging
import os
import re
import shutil
import typing as t
from pathlib import Path
from typing import TYPE_CHECKING

import attr
Expand All @@ -18,6 +16,7 @@
from bentoml.models import ModelContext

from ..models.model import PartialKwargsModelOptions
from .utils.transformers import extract_commit_hash

if TYPE_CHECKING:
from types import ModuleType
Expand Down Expand Up @@ -169,33 +168,6 @@ def _str2cls(
return cls


def _extract_commit_hash(
resolved_dir: str, regex_commit_hash: t.Pattern[str]
) -> str | None:
"""
Extracts the commit hash from a resolved filename toward a cache file.
modified from https://github.com/huggingface/transformers/blob/0b7b4429c78de68acaf72224eb6dae43616d820c/src/transformers/utils/hub.py#L219
"""

resolved_dir = str(Path(resolved_dir).as_posix()) + "/"
search = re.search(r"snapshots/([^/]+)/", resolved_dir)

if search is None:
return None

commit_hash = search.groups()[0]
return commit_hash if regex_commit_hash.match(commit_hash) else None


def _try_import_huggingface_hub():
try:
import huggingface_hub # noqa: F401
except ImportError: # pragma: no cover
raise MissingDependencyException(
"'huggingface_hub' is required in order to download pretrained diffusion models, install with 'pip install huggingface-hub'. For more information, refer to https://huggingface.co/docs/huggingface_hub/quick-start",
)


def get(tag_like: str | Tag) -> bentoml.Model:
"""
Get the BentoML model with the given tag.
Expand Down Expand Up @@ -482,6 +454,12 @@ def import_model(

tag = Tag.from_taglike(name)

try:
model = bentoml.models.get(tag)
return model
except bentoml.exceptions.NotFound:
pass

if sync_with_hub_version:
if tag.version is not None:
logger.warn(
Expand Down Expand Up @@ -525,23 +503,20 @@ def import_model(
)

elif pipeline_class:
_try_import_huggingface_hub()

src_dir = pipeline_class.download(
model_name_or_path, proxies=proxies, revision=revision, variant=variant
)

if sync_with_hub_version:
from huggingface_hub.file_download import REGEX_COMMIT_HASH

version = _extract_commit_hash(src_dir, REGEX_COMMIT_HASH)
version = extract_commit_hash(src_dir, REGEX_COMMIT_HASH)
if version is not None:
if variant is not None:
version = version + "-" + variant
tag.version = version

else:
_try_import_huggingface_hub()
from huggingface_hub import snapshot_download

src_dir = snapshot_download(
Expand All @@ -553,7 +528,7 @@ def import_model(
if sync_with_hub_version:
from huggingface_hub.file_download import REGEX_COMMIT_HASH

version = _extract_commit_hash(src_dir, REGEX_COMMIT_HASH)
version = extract_commit_hash(src_dir, REGEX_COMMIT_HASH)
if version is not None:
tag.version = version

Expand Down
Loading

0 comments on commit 3b8a9d2

Please sign in to comment.