
Commit

Merge branch 'main' into generate-inference-type
Wauplin committed Feb 27, 2024
2 parents 4fefa78 + 247b9fd commit e25bcaf
Showing 9 changed files with 156 additions and 46 deletions.
1 change: 1 addition & 0 deletions setup.py
@@ -40,6 +40,7 @@ def get_version() -> str:

extras["torch"] = [
"torch",
"safetensors",
]
extras["hf_transfer"] = [
"hf_transfer>=0.1.4", # Pin for progress bars
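With safetensors added to the torch extra, installing huggingface_hub with that extra now pulls in both dependencies, e.g.:

pip install "huggingface_hub[torch]"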
19 changes: 2 additions & 17 deletions src/huggingface_hub/hf_api.py
@@ -1474,9 +1474,6 @@ def list_models(
>>> # List only the text classification models
>>> api.list_models(filter="text-classification")
>>> # Using the `ModelFilter`
>>> filt = ModelFilter(task="text-classification")
>>> # List only models from the AllenNLP library
>>> api.list_models(filter="allennlp")
@@ -1500,7 +1497,6 @@ def list_models(
raise ValueError("`emissions_thresholds` were passed without setting `cardData=True`.")

path = f"{self.endpoint}/api/models"
model_str = ""
headers = self._build_hf_headers(token=token)
params = {}
filter_list = []
@@ -1515,10 +1511,9 @@ def list_models(

# Build the filter list
if author:
model_str = f"{author}/"
params.update({"author": author})
if model_name:
model_str += model_name
params.update({"search": model_name})
if library:
filter_list.extend([library] if isinstance(library, str) else library)
if task:
@@ -1534,8 +1529,6 @@ def list_models(
filter_list.extend([language] if isinstance(language, str) else language)
if tags:
filter_list.extend([tags] if isinstance(tags, str) else tags)
if model_str:
params.update({"search": model_str})

if search:
params.update({"search": search})
@@ -1717,16 +1710,12 @@ def list_datasets(
>>> # List only the text classification datasets
>>> api.list_datasets(filter="task_categories:text-classification")
>>> # Using the `DatasetFilter`
>>> filt = DatasetFilter(task_categories="text-classification")
>>> # List only the datasets in russian for language modeling
>>> api.list_datasets(
... filter=("language:ru", "task_ids:language-modeling")
... )
>>> # Using the `DatasetFilter`
>>> filt = DatasetFilter(language="ru", task_ids="language-modeling")
>>> api.list_datasets(filter=filt)
```
@@ -1746,7 +1735,6 @@ def list_datasets(
```
"""
path = f"{self.endpoint}/api/datasets"
dataset_str = ""
headers = self._build_hf_headers(token=token)
params = {}
filter_list = []
@@ -1759,10 +1747,9 @@ def list_datasets(

# Build the filter list
if author:
dataset_str = f"{author}/"
params.update({"author": author})
if dataset_name:
dataset_str += dataset_name
params.update({"search": dataset_name})

for attr in (
benchmark,
@@ -1780,8 +1767,6 @@ def list_datasets(
if not data.startswith(f"{attr}:"):
data = f"{attr}:{data}"
filter_list.append(data)
if dataset_str:
params.update({"search": dataset_str})

if search:
params.update({"search": search})
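The net effect of the hf_api.py change: instead of concatenating author and name into a single search string (the removed model_str / dataset_str), the author is sent as its own query parameter and the name goes into the search parameter. A minimal caller-side sketch, with illustrative repo names:

from huggingface_hub import HfApi

api = HfApi()
# `author` is now its own query parameter; the name portion is a plain search term.
models = api.list_models(author="huggingface", search="bert")
datasets = api.list_datasets(author="huggingface", search="wiki")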
74 changes: 57 additions & 17 deletions src/huggingface_hub/hub_mixin.py
@@ -5,10 +5,18 @@
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional, Type, TypeVar, Union, get_args

from .constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME
from .file_download import hf_hub_download, is_torch_available
from .constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME, SAFETENSORS_SINGLE_FILE
from .file_download import hf_hub_download
from .hf_api import HfApi
from .utils import HfHubHTTPError, SoftTemporaryDirectory, logging, validate_hf_hub_args
from .utils import (
EntryNotFoundError,
HfHubHTTPError,
SoftTemporaryDirectory,
is_safetensors_available,
is_torch_available,
logging,
validate_hf_hub_args,
)
from .utils._deprecation import _deprecate_arguments


@@ -18,6 +26,11 @@
if is_torch_available():
import torch # type: ignore

if is_safetensors_available():
from safetensors import safe_open
from safetensors.torch import save_file


logger = logging.get_logger(__name__)

# Generic variable that is either ModelHubMixin or a subclass thereof
@@ -447,7 +460,7 @@ class PyTorchModelHubMixin(ModelHubMixin):
def _save_pretrained(self, save_directory: Path) -> None:
"""Save weights from a Pytorch model to a local directory."""
model_to_save = self.module if hasattr(self, "module") else self # type: ignore
torch.save(model_to_save.state_dict(), save_directory / PYTORCH_WEIGHTS_NAME)
save_file(model_to_save.state_dict(), save_directory / SAFETENSORS_SINGLE_FILE)

@classmethod
def _from_pretrained(
@@ -466,25 +479,52 @@ def _from_pretrained(
**model_kwargs,
):
"""Load Pytorch pretrained weights and return the loaded model."""
model = cls(**model_kwargs)
if os.path.isdir(model_id):
print("Loading weights from local directory")
model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
return cls._load_as_safetensor(model, model_file, map_location, strict)
else:
model_file = hf_hub_download(
repo_id=model_id,
filename=PYTORCH_WEIGHTS_NAME,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
model = cls(**model_kwargs)
try:
model_file = hf_hub_download(
repo_id=model_id,
filename=SAFETENSORS_SINGLE_FILE,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
return cls._load_as_safetensor(model, model_file, map_location, strict)
except EntryNotFoundError:
model_file = hf_hub_download(
repo_id=model_id,
filename=PYTORCH_WEIGHTS_NAME,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
return cls._load_as_pickle(model, model_file, map_location, strict)

@classmethod
def _load_as_pickle(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
state_dict = torch.load(model_file, map_location=torch.device(map_location))
model.load_state_dict(state_dict, strict=strict) # type: ignore
model.eval() # type: ignore
return model

@classmethod
def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
state_dict = {}
with safe_open(model_file, framework="pt", device=map_location) as f: # type: ignore [attr-defined]
for k in f.keys():
state_dict[k] = f.get_tensor(k)
model.load_state_dict(state_dict, strict=strict) # type: ignore
model.eval() # type: ignore
return model
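Taken together, the mixin now writes model.safetensors in _save_pretrained and, on load, tries the safetensors file first, falling back to pytorch_model.bin only when the Hub raises EntryNotFoundError. A minimal usage sketch, assuming a toy nn.Module:

import torch.nn as nn

from huggingface_hub import PyTorchModelHubMixin

class MyModel(nn.Module, PyTorchModelHubMixin):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)  # illustrative layer

model = MyModel()
model.save_pretrained("./my-model")  # now writes model.safetensors, not pytorch_model.bin
reloaded = MyModel.from_pretrained("./my-model")  # loads weights via safe_open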
7 changes: 4 additions & 3 deletions src/huggingface_hub/inference/_common.py
@@ -84,16 +84,17 @@ class ModelStatus:
backend. Loadable models are automatically loaded when the user first
requests inference on the endpoint. This means it is transparent for the
user to load a model, except that the first call takes longer to complete.
compute_type (`str`):
The type of compute resource the model is using or will use, such as 'gpu' or 'cpu'.
compute_type (`Dict`):
Information about the compute resource the model is using or will use, such as 'gpu' type and number of
replicas.
framework (`str`):
The name of the framework that the model was built with, such as 'transformers'
or 'text-generation-inference'.
"""

loaded: bool
state: str
compute_type: str
compute_type: Dict
framework: str


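Callers that compared compute_type against a string (as the old tests did) must now treat it as a dict. A short sketch, using the example shape from the updated tests:

from huggingface_hub import InferenceClient

status = InferenceClient().get_model_status("bigscience/bloom")
print(status.compute_type)  # e.g. {"gpu": {"gpu": "a100", "count": 8}}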
1 change: 1 addition & 0 deletions src/huggingface_hub/utils/__init__.py
@@ -82,6 +82,7 @@
is_pillow_available,
is_pydantic_available,
is_pydot_available,
is_safetensors_available,
is_tensorboard_available,
is_tf_available,
is_torch_available,
6 changes: 6 additions & 0 deletions src/huggingface_hub/utils/_runtime.py
@@ -38,6 +38,7 @@
"pillow": {"Pillow"},
"pydantic": {"pydantic"},
"pydot": {"pydot"},
"safetensors": {"safetensors"},
"tensorboard": {"tensorboardX"},
"tensorflow": (
"tensorflow",
@@ -229,6 +230,11 @@ def get_torch_version() -> str:
return _get_version("torch")


# Safetensors
def is_safetensors_available() -> bool:
return is_package_available("safetensors")


# Shell-related helpers
try:
# Set to `True` if script is running in a Google Colab notebook.
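The new helper follows the same pattern as the other is_*_available checks, and it is what hub_mixin.py uses to guard its optional import, roughly:

from huggingface_hub.utils import is_safetensors_available

if is_safetensors_available():
    from safetensors.torch import save_file  # optional dependency, only imported when present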
84 changes: 80 additions & 4 deletions tests/test_hub_mixin_pytorch.py
@@ -1,5 +1,6 @@
import json
import os
import struct
import unittest
from pathlib import Path
from typing import TypeVar
@@ -8,14 +9,16 @@
import pytest

from huggingface_hub import HfApi, hf_hub_download
from huggingface_hub.hub_mixin import PyTorchModelHubMixin
from huggingface_hub.utils import SoftTemporaryDirectory, is_torch_available
from huggingface_hub.constants import PYTORCH_WEIGHTS_NAME
from huggingface_hub.hub_mixin import ModelHubMixin, PyTorchModelHubMixin
from huggingface_hub.utils import EntryNotFoundError, HfHubHTTPError, SoftTemporaryDirectory, is_torch_available

from .testing_constants import ENDPOINT_STAGING, TOKEN, USER
from .testing_utils import repo_name, requires


if is_torch_available():
import torch
import torch.nn as nn

CONFIG = {"num": 10, "act": "gelu_fast"}
@@ -48,16 +51,26 @@ def setUpClass(cls):
def test_save_pretrained_basic(self):
DummyModel().save_pretrained(self.cache_dir)
files = os.listdir(self.cache_dir)
self.assertTrue("pytorch_model.bin" in files)
self.assertTrue("model.safetensors" in files)
self.assertEqual(len(files), 1)

def test_save_pretrained_with_config(self):
DummyModel().save_pretrained(self.cache_dir, config=CONFIG)
files = os.listdir(self.cache_dir)
self.assertTrue("config.json" in files)
self.assertTrue("pytorch_model.bin" in files)
self.assertTrue("model.safetensors" in files)
self.assertEqual(len(files), 2)

def test_save_as_safetensors(self):
DummyModel().save_pretrained(self.cache_dir, config=CONFIG)
modelFile = self.cache_dir / "model.safetensors"
# check for safetensors header to ensure we are saving the model in safetensors format
# while an implementation detail, assert as this has safety implications
# https://github.com/huggingface/safetensors?tab=readme-ov-file#format
with open(modelFile, "rb") as f:
header_size = struct.unpack("<Q", f.read(8))[0]
self.assertEqual(header_size, 128)

def test_save_pretrained_with_push_to_hub(self):
repo_id = repo_name("save")
save_directory = self.cache_dir / repo_id
Expand Down Expand Up @@ -85,6 +98,69 @@ def test_from_pretrained_model_id_only(self, from_pretrained_mock: Mock) -> None
from_pretrained_mock.assert_called_once()
self.assertIs(model, from_pretrained_mock.return_value)

def pretend_file_download(self, **kwargs):
if kwargs.get("filename") == "config.json":
raise HfHubHTTPError("no config")
DummyModel().save_pretrained(self.cache_dir)
return self.cache_dir / "model.safetensors"

@patch("huggingface_hub.hub_mixin.hf_hub_download")
def test_from_pretrained_model_from_hub_prefer_safetensor(self, hf_hub_download_mock: Mock) -> None:
hf_hub_download_mock.side_effect = self.pretend_file_download
model = DummyModel.from_pretrained("namespace/repo_name")
hf_hub_download_mock.assert_any_call(
repo_id="namespace/repo_name",
filename="model.safetensors",
revision=None,
cache_dir=None,
force_download=False,
proxies=None,
resume_download=False,
token=None,
local_files_only=False,
)
self.assertIsNotNone(model)

def pretend_file_download_fallback(self, **kwargs):
filename = kwargs.get("filename")
if filename == "model.safetensors" or filename == "config.json":
raise EntryNotFoundError("not found")

class TestMixin(ModelHubMixin):
def _save_pretrained(self, save_directory: Path) -> None:
torch.save(DummyModel().state_dict(), save_directory / PYTORCH_WEIGHTS_NAME)

TestMixin().save_pretrained(self.cache_dir)
return self.cache_dir / PYTORCH_WEIGHTS_NAME

@patch("huggingface_hub.hub_mixin.hf_hub_download")
def test_from_pretrained_model_from_hub_fallback_pickle(self, hf_hub_download_mock: Mock) -> None:
hf_hub_download_mock.side_effect = self.pretend_file_download_fallback
model = DummyModel.from_pretrained("namespace/repo_name")
hf_hub_download_mock.assert_any_call(
repo_id="namespace/repo_name",
filename="model.safetensors",
revision=None,
cache_dir=None,
force_download=False,
proxies=None,
resume_download=False,
token=None,
local_files_only=False,
)
hf_hub_download_mock.assert_any_call(
repo_id="namespace/repo_name",
filename="pytorch_model.bin",
revision=None,
cache_dir=None,
force_download=False,
proxies=None,
resume_download=False,
token=None,
local_files_only=False,
)
self.assertIsNotNone(model)

@patch.object(DummyModel, "_from_pretrained")
def test_from_pretrained_model_id_and_revision(self, from_pretrained_mock: Mock) -> None:
"""Regression test for #1313.
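The header check in test_save_as_safetensors relies on the safetensors file layout linked in the test comment: the first 8 bytes are a little-endian unsigned 64-bit integer giving the length of the JSON header that describes the tensors. A standalone sketch of the same check (the file path is illustrative):

import json
import struct

with open("model.safetensors", "rb") as f:
    header_size = struct.unpack("<Q", f.read(8))[0]  # 8-byte little-endian u64
    header = json.loads(f.read(header_size))  # JSON table describing each tensor
print(sorted(header.keys()))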
2 changes: 1 addition & 1 deletion tests/test_inference_async_client.py
@@ -216,7 +216,7 @@ async def test_get_status_loaded_model() -> None:
model_status = await AsyncInferenceClient().get_model_status("bigscience/bloom")
assert model_status.loaded is True
assert model_status.state == "Loaded"
assert model_status.compute_type == "gpu"
assert isinstance(model_status.compute_type, dict) # e.g. {'gpu': {'gpu': 'a100', 'count': 8}}
assert model_status.framework == "text-generation-inference"


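The async client mirrors the sync assertion; a minimal sketch of checking the new dict-shaped compute_type asynchronously:

import asyncio

from huggingface_hub import AsyncInferenceClient

async def main():
    status = await AsyncInferenceClient().get_model_status("bigscience/bloom")
    assert isinstance(status.compute_type, dict)

asyncio.run(main())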
8 changes: 4 additions & 4 deletions tests/test_inference_client.py
@@ -592,10 +592,10 @@ def test_too_big_model(self) -> None:
def test_loaded_model(self) -> None:
client = InferenceClient()
model_status = client.get_model_status("bigscience/bloom")
self.assertTrue(model_status.loaded)
self.assertEqual(model_status.state, "Loaded")
self.assertEqual(model_status.compute_type, "gpu")
self.assertEqual(model_status.framework, "text-generation-inference")
assert model_status.loaded
assert model_status.state == "Loaded"
assert isinstance(model_status.compute_type, dict) # e.g. {'gpu': {'gpu': 'a100', 'count': 8}}
assert model_status.framework == "text-generation-inference"

def test_unknown_model(self) -> None:
client = InferenceClient()
