-
Notifications
You must be signed in to change notification settings - Fork 459
/
hub_mixin.py
788 lines (696 loc) · 33.8 KB
/
hub_mixin.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
import inspect
import json
import os
from dataclasses import asdict, dataclass, is_dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union, get_args
from .constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME, SAFETENSORS_SINGLE_FILE
from .file_download import hf_hub_download
from .hf_api import HfApi
from .repocard import ModelCard, ModelCardData
from .utils import (
EntryNotFoundError,
HfHubHTTPError,
SoftTemporaryDirectory,
is_jsonable,
is_safetensors_available,
is_torch_available,
logging,
validate_hf_hub_args,
)
if TYPE_CHECKING:
from _typeshed import DataclassInstance
if is_torch_available():
import torch # type: ignore
if is_safetensors_available():
from safetensors.torch import load_model as load_model_as_safetensor
from safetensors.torch import save_model as save_model_as_safetensor
# Module-level logger, named after this module.
logger = logging.get_logger(__name__)
# Generic variable that is either ModelHubMixin or a subclass thereof
T = TypeVar("T", bound="ModelHubMixin")
# Generic variable to represent an args type
ARGS_T = TypeVar("ARGS_T")
# An encoder takes an argument of a custom type; its result is what gets stored in the config.
ENCODER_T = Callable[[ARGS_T], Any]
# A decoder takes a value read back from the config and returns the custom-typed argument.
DECODER_T = Callable[[Any], ARGS_T]
# Coders are always registered as an (encoder, decoder) pair, in that order.
CODER_T = Tuple[ENCODER_T, DECODER_T]
# Default model card template, used when a subclass does not pass `model_card_template`.
# The `{{ card_data }}`, `{{ repo_url }}` and `{{ docs_url }}` placeholders are filled with the
# values given to `ModelCard.from_template` in `ModelHubMixin.generate_model_card`.
DEFAULT_MODEL_CARD = """
---
# For reference on model card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
# Doc / guide: https://huggingface.co/docs/hub/model-cards
{{ card_data }}
---
This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
- Library: {{ repo_url | default("[More Information Needed]", true) }}
- Docs: {{ docs_url | default("[More Information Needed]", true) }}
"""
@dataclass
class MixinInfo:
    """Per-class metadata about the library integrating `ModelHubMixin`.

    One instance is built in `ModelHubMixin.__init_subclass__` and stored as `cls._hub_mixin_info`.
    It is read back by `ModelHubMixin.generate_model_card` to render the model card.
    """

    # Template string passed as `template_str` when rendering the model card.
    model_card_template: str
    # Model card metadata (languages, license, tags, ...) passed as `card_data`.
    model_card_data: ModelCardData
    # URL of the library repository; optional, forwarded to the template as `repo_url`.
    repo_url: Optional[str] = None
    # URL of the library documentation; optional, forwarded to the template as `docs_url`.
    docs_url: Optional[str] = None
class ModelHubMixin:
    """
    A generic mixin to integrate ANY machine learning framework with the Hub.

    To integrate your framework, your model class must inherit from this class. Custom logic for saving/loading
    models has to be overridden in [`_from_pretrained`] and [`_save_pretrained`]. [`PyTorchModelHubMixin`] is a good
    example of mixin integration with the Hub. Check out our [integration guide](../guides/integrations) for more
    instructions.

    When inheriting from [`ModelHubMixin`], you can define class-level attributes. These attributes are not passed to
    `__init__` but to the class definition itself. This is useful to define metadata about the library integrating
    [`ModelHubMixin`].

    For more details on how to integrate the mixin with your library, check out the
    [integration guide](../guides/integrations).

    Args:
        repo_url (`str`, *optional*):
            URL of the library repository. Used to generate model card.
        docs_url (`str`, *optional*):
            URL of the library documentation. Used to generate model card.
        model_card_template (`str`, *optional*):
            Template of the model card. Used to generate model card. Defaults to a generic template.
        languages (`List[str]`, *optional*):
            Languages supported by the library. Used to generate model card.
        library_name (`str`, *optional*):
            Name of the library integrating ModelHubMixin. Used to generate model card.
        license (`str`, *optional*):
            License of the library integrating ModelHubMixin. Used to generate model card.
            E.g: "apache-2.0"
        license_name (`str`, *optional*):
            Name of the license of the library integrating ModelHubMixin. Used to generate model card.
            Only used if `license` is set to `other`.
            E.g: "coqui-public-model-license".
        license_link (`str`, *optional*):
            URL to the license of the library integrating ModelHubMixin. Used to generate model card.
            Only used if `license` is set to `other` and `license_name` is set.
            E.g: "https://coqui.ai/cpml".
        pipeline_tag (`str`, *optional*):
            Tag of the pipeline. Used to generate model card. E.g. "text-classification".
        tags (`List[str]`, *optional*):
            Tags to be added to the model card. Used to generate model card. E.g. ["x-custom-tag", "arxiv:2304.12244"]
        coders (`Dict[Type, Tuple[Callable, Callable]]`, *optional*):
            Dictionary of custom types and their encoders/decoders. Used to encode/decode arguments that are not
            jsonable by default. E.g dataclasses, argparse.Namespace, OmegaConf, etc.

    Example:

    ```python
    >>> from huggingface_hub import ModelHubMixin

    # Inherit from ModelHubMixin
    >>> class MyCustomModel(
    ...         ModelHubMixin,
    ...         library_name="my-library",
    ...         tags=["x-custom-tag", "arxiv:2304.12244"],
    ...         repo_url="https://github.com/huggingface/my-cool-library",
    ...         docs_url="https://huggingface.co/docs/my-cool-library",
    ...         # ^ optional metadata to generate model card
    ...     ):
    ...     def __init__(self, size: int = 512, device: str = "cpu"):
    ...         # define how to initialize your model
    ...         super().__init__()
    ...         ...
    ...
    ...     def _save_pretrained(self, save_directory: Path) -> None:
    ...         # define how to serialize your model
    ...         ...
    ...
    ...     @classmethod
    ...     def from_pretrained(
    ...         cls: Type[T],
    ...         pretrained_model_name_or_path: Union[str, Path],
    ...         *,
    ...         force_download: bool = False,
    ...         resume_download: Optional[bool] = None,
    ...         proxies: Optional[Dict] = None,
    ...         token: Optional[Union[str, bool]] = None,
    ...         cache_dir: Optional[Union[str, Path]] = None,
    ...         local_files_only: bool = False,
    ...         revision: Optional[str] = None,
    ...         **model_kwargs,
    ...     ) -> T:
    ...         # define how to deserialize your model
    ...         ...

    >>> model = MyCustomModel(size=256, device="gpu")

    # Save model weights to local directory
    >>> model.save_pretrained("my-awesome-model")

    # Push model weights to the Hub
    >>> model.push_to_hub("my-awesome-model")

    # Download and initialize weights from the Hub
    >>> reloaded_model = MyCustomModel.from_pretrained("username/my-awesome-model")
    >>> reloaded_model.size
    256

    # Model card has been correctly populated
    # ("model_hub_mixin" is appended to the user-provided tags when subclassing)
    >>> from huggingface_hub import ModelCard
    >>> card = ModelCard.load("username/my-awesome-model")
    >>> card.data.tags
    ["x-custom-tag", "arxiv:2304.12244", "model_hub_mixin"]
    >>> card.data.library_name
    "my-library"
    ```
    """

    _hub_mixin_config: Optional[Union[dict, "DataclassInstance"]] = None
    # ^ optional config attribute automatically set in `from_pretrained`
    _hub_mixin_info: MixinInfo
    # ^ information about the library integrating ModelHubMixin (used to generate model card)
    _hub_mixin_inject_config: bool  # whether `_from_pretrained` expects `config` or not
    _hub_mixin_init_parameters: Dict[str, inspect.Parameter]  # __init__ parameters
    _hub_mixin_jsonable_default_values: Dict[str, Any]  # default values for __init__ parameters
    _hub_mixin_jsonable_custom_types: Tuple[Type, ...]  # custom types that can be encoded/decoded
    _hub_mixin_coders: Dict[Type, CODER_T]  # encoders/decoders for custom types
    # ^ internal values to handle config
def __init_subclass__(
    cls,
    *,
    # Generic info for model card
    repo_url: Optional[str] = None,
    docs_url: Optional[str] = None,
    # Model card template
    model_card_template: str = DEFAULT_MODEL_CARD,
    # Model card metadata
    languages: Optional[List[str]] = None,
    library_name: Optional[str] = None,
    license: Optional[str] = None,
    license_name: Optional[str] = None,
    license_link: Optional[str] = None,
    pipeline_tag: Optional[str] = None,
    tags: Optional[List[str]] = None,
    # How to encode/decode arguments with custom type into a JSON config?
    coders: Optional[
        Dict[Type, CODER_T]
        # Key is a type.
        # Value is a tuple (encoder, decoder).
        # Example: {MyCustomType: (lambda x: x.value, lambda data: MyCustomType(data))}
    ] = None,
) -> None:
    """Inspect __init__ signature only once when subclassing + handle modelcard.

    Runs once per subclass definition. Stores on `cls`:
    - `_hub_mixin_info`: model card template/metadata built from the class-definition keyword arguments.
    - `_hub_mixin_coders` / `_hub_mixin_jsonable_custom_types`: custom (encoder, decoder) pairs and their types.
    - `_hub_mixin_init_parameters`: the parameters of `cls.__init__` (including `self`).
    - `_hub_mixin_jsonable_default_values`: encoded default values of `__init__` parameters that are jsonable.
    - `_hub_mixin_inject_config`: whether `cls._from_pretrained` declares a `config` parameter.

    The `tags` list passed by the caller is never modified: the "model_hub_mixin" tag is appended to a copy.
    """
    super().__init_subclass__()

    # Will be reused when creating modelcard.
    # Work on a copy: appending to the caller's own list would leak the extra tag into every other
    # class definition sharing that list (and add it several times).
    tags = list(tags) if tags else []
    tags.append("model_hub_mixin")
    cls._hub_mixin_info = MixinInfo(
        model_card_template=model_card_template,
        repo_url=repo_url,
        docs_url=docs_url,
        model_card_data=ModelCardData(
            languages=languages,
            library_name=library_name,
            license=license,
            license_name=license_name,
            license_link=license_link,
            pipeline_tag=pipeline_tag,
            tags=tags,
        ),
    )

    # Handle encoders/decoders for args
    cls._hub_mixin_coders = coders or {}
    cls._hub_mixin_jsonable_custom_types = tuple(cls._hub_mixin_coders.keys())

    # Inspect __init__ signature to handle config
    cls._hub_mixin_init_parameters = dict(inspect.signature(cls.__init__).parameters)
    cls._hub_mixin_jsonable_default_values = {
        param.name: cls._encode_arg(param.default)
        for param in cls._hub_mixin_init_parameters.values()
        if param.default is not inspect.Parameter.empty and cls._is_jsonable(param.default)
    }
    cls._hub_mixin_inject_config = "config" in inspect.signature(cls._from_pretrained).parameters
def __new__(cls, *args, **kwargs) -> "ModelHubMixin":
    """Instantiate the class and record the init arguments as `_hub_mixin_config`.

    Three situations are handled, in this order:
    - `_hub_mixin_config` is already set on the new instance: nothing to do.
    - `config` was passed as a dataclass: it becomes `_hub_mixin_config` as is.
    - otherwise: `_hub_mixin_config` is assembled from the jsonable default values of `__init__`,
      overridden by the jsonable values actually passed, then updated with `config` if it is a dict.
      An empty result leaves `_hub_mixin_config` untouched.
    """
    instance = super().__new__(cls)
    if instance._hub_mixin_config is not None:
        return instance

    # Map positional arguments to their parameter names ([1:] skips `self`), then apply keyword arguments
    positional_names = list(cls._hub_mixin_init_parameters)[1:]
    passed_values = dict(zip(positional_names, args))
    passed_values.update(kwargs)

    provided_config = passed_values.get("config")
    if is_dataclass(provided_config):
        # Dataclass config => stored untouched
        instance._hub_mixin_config = provided_config
        return instance

    # Defaults first, passed values on top (custom types are encoded, non-jsonable values are dropped)
    init_config = dict(cls._hub_mixin_jsonable_default_values)
    for name, value in passed_values.items():
        if instance._is_jsonable(value):
            init_config[name] = cls._encode_arg(value)

    # `config` itself is not stored as a key: its content is merged at the top level instead
    init_config.pop("config", None)
    if isinstance(provided_config, dict):
        init_config.update(provided_config)

    if init_config:
        instance._hub_mixin_config = init_config
    return instance
@classmethod
def _is_jsonable(cls, value: Any) -> bool:
    """Tell whether `value` can be stored in the config.

    Instances of a type registered in `coders` always qualify; anything else is delegated to `is_jsonable`.
    """
    return isinstance(value, cls._hub_mixin_jsonable_custom_types) or is_jsonable(value)
@classmethod
def _encode_arg(cls, arg: Any) -> Any:
    """Return `arg` encoded by the first registered coder whose type matches, or `arg` itself if none does."""
    for handled_type, coder_pair in cls._hub_mixin_coders.items():
        encode, _decode = coder_pair
        if not isinstance(arg, handled_type):
            continue
        return encode(arg)
    # No custom encoder registered for this value => keep it as is
    return arg
@classmethod
def _decode_arg(cls, expected_type: Type[ARGS_T], value: Any) -> ARGS_T:
    """Decode a JSON serializable value into an argument.

    Args:
        expected_type:
            Annotation of the `__init__` parameter the value is meant for.
        value:
            Value as read from the config.

    Returns:
        The output of the first registered decoder whose type is a base of `expected_type`. If no decoder
        matches, or if `expected_type` is not a plain class (e.g. `Optional[...]`, `List[...]` or a string
        annotation), `value` is returned unchanged.
    """
    # `issubclass` raises TypeError when its first argument is not a class. Annotations are free-form,
    # so anything that is not a class cannot match a registered type and is passed through.
    if not inspect.isclass(expected_type):
        return value
    for type_, (_, decoder) in cls._hub_mixin_coders.items():
        try:
            is_match = issubclass(expected_type, type_)
        except TypeError:
            # Some parametrized generics (e.g. `list[int]`) pass `inspect.isclass` on older Python
            # versions but are still rejected by `issubclass`.
            return value
        if is_match:
            return decoder(value)
    return value
def save_pretrained(
    self,
    save_directory: Union[str, Path],
    *,
    config: Optional[Union[dict, "DataclassInstance"]] = None,
    repo_id: Optional[str] = None,
    push_to_hub: bool = False,
    **push_to_hub_kwargs,
) -> Optional[str]:
    """
    Save weights, config and model card in a local directory, and optionally push them to the Hub.

    Args:
        save_directory (`str` or `Path`):
            Path to directory in which the model weights and configuration will be saved. Created if missing.
        config (`dict` or `DataclassInstance`, *optional*):
            Model configuration specified as a key/value dictionary or a dataclass instance. Falls back to
            the config recorded on the instance when not provided.
        push_to_hub (`bool`, *optional*, defaults to `False`):
            Whether or not to push your model to the Huggingface Hub after saving it.
        repo_id (`str`, *optional*):
            ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if
            not provided.
        kwargs:
            Additional key word arguments passed along to the [`~ModelHubMixin.push_to_hub`] method.

    Returns:
        `str` or `None`: url of the commit on the Hub if `push_to_hub=True`, `None` otherwise.
    """
    target_dir = Path(save_directory)
    target_dir.mkdir(parents=True, exist_ok=True)

    # Start from a clean slate: a config.json left over from a previous save must not survive, whereas
    # one written by `_save_pretrained` below must not be overwritten afterwards.
    config_path = target_dir / CONFIG_NAME
    config_path.unlink(missing_ok=True)

    # Framework-specific serialization of weights/files
    self._save_pretrained(target_dir)

    # Config: explicit argument first, instance config otherwise
    if config is None:
        config = self._hub_mixin_config
    if config is not None:
        if is_dataclass(config):
            config = asdict(config)  # type: ignore[arg-type]
        if not config_path.exists():
            config_path.write_text(json.dumps(config, sort_keys=True, indent=2))

    # Model card: an existing README.md is kept untouched
    readme_path = target_dir / "README.md"
    if not readme_path.exists():
        self.generate_model_card().save(readme_path)

    if not push_to_hub:
        return None

    # Shallow copy so that the caller's kwargs are not modified
    hub_kwargs = dict(push_to_hub_kwargs)
    if config is not None:
        hub_kwargs["config"] = config
    if repo_id is None:
        repo_id = target_dir.name  # Defaults to `save_directory` name
    return self.push_to_hub(repo_id=repo_id, **hub_kwargs)
def _save_pretrained(self, save_directory: Path) -> None:
    """
    Overwrite this method in subclass to define how to save your model.

    Check out our [integration guide](../guides/integrations) for instructions.

    This base implementation only raises `NotImplementedError`. It is called by
    [`~ModelHubMixin.save_pretrained`] once the directory has been created.

    Args:
        save_directory (`Path`):
            Path to directory in which the model weights and configuration will be saved.

    Raises:
        `NotImplementedError`: always, unless overridden in a subclass.
    """
    raise NotImplementedError
@classmethod
@validate_hf_hub_args
def from_pretrained(
    cls: Type[T],
    pretrained_model_name_or_path: Union[str, Path],
    *,
    force_download: bool = False,
    resume_download: Optional[bool] = None,
    proxies: Optional[Dict] = None,
    token: Optional[Union[str, bool]] = None,
    cache_dir: Optional[Union[str, Path]] = None,
    local_files_only: bool = False,
    revision: Optional[str] = None,
    **model_kwargs,
) -> T:
    """
    Download a model from the Huggingface Hub and instantiate it.

    The config file (if any) is read first, either from the local directory or from the Hub. Its values
    are used to fill the init arguments that were not explicitly passed in `model_kwargs`; values passed
    by the caller always take precedence, including an explicit `config=...`.

    Args:
        pretrained_model_name_or_path (`str`, `Path`):
            - Either the `model_id` (string) of a model hosted on the Hub, e.g. `bigscience/bloom`.
            - Or a path to a `directory` containing model weights saved using
                [`~ModelHubMixin.save_pretrained`], e.g., `../path/to/my_model_directory/`.
        revision (`str`, *optional*):
            Revision of the model on the Hub. Can be a branch name, a git tag or any commit id.
            Defaults to the latest commit on `main` branch.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding
            the existing cache.
        resume_download (`bool`, *optional*):
            Forwarded as is to `hf_hub_download` and to [`~ModelHubMixin._from_pretrained`].
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}`. The proxies are used on every request.
        token (`str` or `bool`, *optional*):
            The token to use as HTTP bearer authorization for remote files. By default, it will use the token
            cached when running `huggingface-cli login`.
        cache_dir (`str`, `Path`, *optional*):
            Path to the folder where cached files are stored.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, avoid downloading the file and return the path to the local cached file if it exists.
        model_kwargs (`Dict`, *optional*):
            Additional kwargs to pass to the model during initialization.

    Returns:
        The instance returned by [`~ModelHubMixin._from_pretrained`]. If a config was found and the instance
        has no config yet, the loaded config is attached to it as `_hub_mixin_config`.
    """
    model_id = str(pretrained_model_name_or_path)
    # Remember whether the caller explicitly provided `config`: in that case it must never be replaced
    # by the config loaded from disk/Hub (same precedence rule as for every other kwarg).
    user_passed_config = "config" in model_kwargs

    config_file: Optional[str] = None
    if os.path.isdir(model_id):
        if CONFIG_NAME in os.listdir(model_id):
            config_file = os.path.join(model_id, CONFIG_NAME)
        else:
            logger.warning(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
    else:
        try:
            config_file = hf_hub_download(
                repo_id=model_id,
                filename=CONFIG_NAME,
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                token=token,
                local_files_only=local_files_only,
            )
        except HfHubHTTPError as e:
            # A missing config is not an error: the model is simply initialized without it
            logger.info(f"{CONFIG_NAME} not found on the HuggingFace Hub: {str(e)}")

    # Read config
    config = None
    if config_file is not None:
        with open(config_file, "r", encoding="utf-8") as f:
            config = json.load(f)

        # Decode custom types in config
        for key, value in config.items():
            if key in cls._hub_mixin_init_parameters:
                expected_type = cls._hub_mixin_init_parameters[key].annotation
                if expected_type is not inspect.Parameter.empty:
                    config[key] = cls._decode_arg(expected_type, value)

        # Populate model_kwargs from config
        for param in cls._hub_mixin_init_parameters.values():
            if param.name not in model_kwargs and param.name in config:
                model_kwargs[param.name] = config[param.name]

        # Check if `config` argument is expected at init (and was not provided by the caller)
        if "config" in cls._hub_mixin_init_parameters and not user_passed_config:
            # Check if `config` argument is a dataclass
            config_annotation = cls._hub_mixin_init_parameters["config"].annotation
            if config_annotation is inspect.Parameter.empty:
                pass  # no annotation
            elif is_dataclass(config_annotation):
                config = _load_dataclass(config_annotation, config)
            else:
                # if Optional/Union annotation => check if a dataclass is in the Union
                for _sub_annotation in get_args(config_annotation):
                    if is_dataclass(_sub_annotation):
                        config = _load_dataclass(_sub_annotation, config)
                        break

            # Forward config to model initialization
            model_kwargs["config"] = config

        # Inject config if `**kwargs` are expected.
        # Only possible while `config` is still a key/value mapping: once converted to a dataclass
        # instance above, it is delivered through the `config` argument only (membership tests and
        # `.items()` are not available on a dataclass instance).
        if isinstance(config, dict):
            if is_dataclass(cls):
                for key in cls.__dataclass_fields__:
                    if key not in model_kwargs and key in config:
                        model_kwargs[key] = config[key]
            elif any(
                param.kind == inspect.Parameter.VAR_KEYWORD for param in cls._hub_mixin_init_parameters.values()
            ):
                for key, value in config.items():
                    if key not in model_kwargs:
                        model_kwargs[key] = value

        # Finally, also inject if `_from_pretrained` expects it (unless provided by the caller)
        if cls._hub_mixin_inject_config and not user_passed_config:
            model_kwargs["config"] = config

    instance = cls._from_pretrained(
        model_id=str(model_id),
        revision=revision,
        cache_dir=cache_dir,
        force_download=force_download,
        proxies=proxies,
        resume_download=resume_download,
        local_files_only=local_files_only,
        token=token,
        **model_kwargs,
    )

    # Implicitly set the config as instance attribute if not already set by the class
    # This way `config` will be available when calling `save_pretrained` or `push_to_hub`.
    if config is not None and (getattr(instance, "_hub_mixin_config", None) in (None, {})):
        instance._hub_mixin_config = config
    return instance
@classmethod
def _from_pretrained(
    cls: Type[T],
    *,
    model_id: str,
    revision: Optional[str],
    cache_dir: Optional[Union[str, Path]],
    force_download: bool,
    proxies: Optional[Dict],
    resume_download: Optional[bool],
    local_files_only: bool,
    token: Optional[Union[str, bool]],
    **model_kwargs,
) -> T:
    """Overwrite this method in subclass to define how to load your model from pretrained.

    Use [`hf_hub_download`] or [`snapshot_download`] to download files from the Hub before loading them. Most
    args taken as input can be directly passed to those 2 methods. If needed, you can add more arguments to this
    method using "model_kwargs". For example [`PyTorchModelHubMixin._from_pretrained`] takes as input a `map_location`
    parameter to set on which device the model should be loaded.

    Check out our [integration guide](../guides/integrations) for more instructions.

    This base implementation only raises `NotImplementedError`. All arguments are keyword-only.

    Args:
        model_id (`str`):
            ID of the model to load from the Huggingface Hub (e.g. `bigscience/bloom`).
        revision (`str`, *optional*):
            Revision of the model on the Hub. Can be a branch name, a git tag or any commit id. Defaults to the
            latest commit on `main` branch.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding
            the existing cache.
        resume_download (`bool`, *optional*):
            Value received by [`~ModelHubMixin.from_pretrained`], forwarded unchanged.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint (e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}`).
        token (`str` or `bool`, *optional*):
            The token to use as HTTP bearer authorization for remote files. By default, it will use the token
            cached when running `huggingface-cli login`.
        cache_dir (`str`, `Path`, *optional*):
            Path to the folder where cached files are stored.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, avoid downloading the file and return the path to the local cached file if it exists.
        model_kwargs:
            Additional keyword arguments passed along to the [`~ModelHubMixin._from_pretrained`] method.

    Raises:
        `NotImplementedError`: always, unless overridden in a subclass.
    """
    raise NotImplementedError
@validate_hf_hub_args
def push_to_hub(
    self,
    repo_id: str,
    *,
    config: Optional[Union[dict, "DataclassInstance"]] = None,
    commit_message: str = "Push model using huggingface_hub.",
    private: bool = False,
    token: Optional[str] = None,
    branch: Optional[str] = None,
    create_pr: Optional[bool] = None,
    allow_patterns: Optional[Union[List[str], str]] = None,
    ignore_patterns: Optional[Union[List[str], str]] = None,
    delete_patterns: Optional[Union[List[str], str]] = None,
) -> str:
    """
    Upload model checkpoint to the Hub.

    The repository is created if needed, the model is saved with [`~ModelHubMixin.save_pretrained`] into a
    temporary directory, and that directory is uploaded in a single commit.

    Use `allow_patterns` and `ignore_patterns` to precisely filter which files should be pushed to the hub. Use
    `delete_patterns` to delete existing remote files in the same commit. See [`upload_folder`] reference for more
    details.

    Args:
        repo_id (`str`):
            ID of the repository to push to (example: `"username/my-model"`).
        config (`dict` or `DataclassInstance`, *optional*):
            Model configuration specified as a key/value dictionary or a dataclass instance.
        commit_message (`str`, *optional*):
            Message to commit while pushing.
        private (`bool`, *optional*, defaults to `False`):
            Whether the repository created should be private.
        token (`str`, *optional*):
            The token to use as HTTP bearer authorization for remote files. By default, it will use the token
            cached when running `huggingface-cli login`.
        branch (`str`, *optional*):
            The git branch on which to push the model. This defaults to `"main"`.
        create_pr (`boolean`, *optional*):
            Whether or not to create a Pull Request from `branch` with that commit. Defaults to `False`.
        allow_patterns (`List[str]` or `str`, *optional*):
            If provided, only files matching at least one pattern are pushed.
        ignore_patterns (`List[str]` or `str`, *optional*):
            If provided, files matching any of the patterns are not pushed.
        delete_patterns (`List[str]` or `str`, *optional*):
            If provided, remote files matching any of the patterns will be deleted from the repo.

    Returns:
        The url of the commit of your model in the given repository.
    """
    client = HfApi(token=token)
    # Use the repo id returned by `create_repo` from here on
    repo_id = client.create_repo(repo_id=repo_id, private=private, exist_ok=True).repo_id

    # Stage all files in a temporary folder, then push them to the repo in a single commit
    with SoftTemporaryDirectory() as tmp_root:
        staging_dir = Path(tmp_root) / repo_id
        self.save_pretrained(staging_dir, config=config)
        commit_url = client.upload_folder(
            repo_id=repo_id,
            repo_type="model",
            folder_path=staging_dir,
            commit_message=commit_message,
            revision=branch,
            create_pr=create_pr,
            allow_patterns=allow_patterns,
            ignore_patterns=ignore_patterns,
            delete_patterns=delete_patterns,
        )
    return commit_url
def generate_model_card(self, *args, **kwargs) -> ModelCard:
    """Build the model card from the template and metadata stored in `_hub_mixin_info`.

    Positional and keyword arguments are accepted but not used by this implementation.
    """
    info = self._hub_mixin_info
    return ModelCard.from_template(
        card_data=info.model_card_data,
        template_str=info.model_card_template,
        repo_url=info.repo_url,
        docs_url=info.docs_url,
    )
class PyTorchModelHubMixin(ModelHubMixin):
    """
    Implementation of [`ModelHubMixin`] to provide model Hub upload/download capabilities to PyTorch models. The model
    is set in evaluation mode by default using `model.eval()` (dropout modules are deactivated). To train the model,
    you should first set it back in training mode with `model.train()`.

    Weights are saved as a single safetensors file. When loading from the Hub, the safetensors file is tried
    first, with a fallback on the pickled PyTorch weights file if that entry is not found in the repository.

    See [`ModelHubMixin`] for more details on how to use the mixin.

    Example:

    ```python
    >>> import torch
    >>> import torch.nn as nn
    >>> from huggingface_hub import PyTorchModelHubMixin

    >>> class MyModel(
    ...         nn.Module,
    ...         PyTorchModelHubMixin,
    ...         library_name="keras-nlp",
    ...         repo_url="https://github.com/keras-team/keras-nlp",
    ...         docs_url="https://keras.io/keras_nlp/",
    ...         # ^ optional metadata to generate model card
    ...     ):
    ...     def __init__(self, hidden_size: int = 512, vocab_size: int = 30000, output_size: int = 4):
    ...         super().__init__()
    ...         self.param = nn.Parameter(torch.rand(hidden_size, vocab_size))
    ...         self.linear = nn.Linear(output_size, vocab_size)

    ...     def forward(self, x):
    ...         return self.linear(x + self.param)
    >>> model = MyModel(hidden_size=256)

    # Save model weights to local directory
    >>> model.save_pretrained("my-awesome-model")

    # Push model weights to the Hub
    >>> model.push_to_hub("my-awesome-model")

    # Download and initialize weights from the Hub
    # (`hidden_size=256` is restored from the saved config)
    >>> model = MyModel.from_pretrained("username/my-awesome-model")
    >>> model.param.shape
    torch.Size([256, 30000])
    ```
    """
def __init_subclass__(cls, *args, tags: Optional[List[str]] = None, **kwargs) -> None:
    """Add the "pytorch_model_hub_mixin" tag, then defer to the parent `__init_subclass__`.

    Args:
        tags (`List[str]`, *optional*):
            Tags given in the class definition. The list is copied, never modified in place.
        *args, **kwargs:
            Forwarded unchanged to the parent implementation.
    """
    # Copy before appending: the list belongs to the caller and may be shared by several class
    # definitions, which would otherwise all see (and accumulate) the extra tags.
    tags = list(tags) if tags else []
    tags.append("pytorch_model_hub_mixin")
    kwargs["tags"] = tags
    return super().__init_subclass__(*args, **kwargs)
def _save_pretrained(self, save_directory: Path) -> None:
    """Write the model weights as a single safetensors file inside `save_directory`.

    If the object exposes a `module` attribute, that attribute is what gets saved instead of `self`.
    """
    model_to_save = getattr(self, "module", self)  # type: ignore
    weights_path = save_directory / SAFETENSORS_SINGLE_FILE
    save_model_as_safetensor(model_to_save, str(weights_path))
@classmethod
def _from_pretrained(
    cls,
    *,
    model_id: str,
    revision: Optional[str],
    cache_dir: Optional[Union[str, Path]],
    force_download: bool,
    proxies: Optional[Dict],
    resume_download: Optional[bool],
    local_files_only: bool,
    token: Union[str, bool, None],
    map_location: str = "cpu",
    strict: bool = False,
    **model_kwargs,
):
    """Instantiate the model with `model_kwargs`, then load pretrained weights into it.

    - `model_id` is a local directory: the safetensors file of that directory is loaded.
    - otherwise: the safetensors file is downloaded from the Hub and loaded; if that raises
      `EntryNotFoundError`, the pickled PyTorch weights file is downloaded and loaded instead.

    `map_location` and `strict` are passed to the loading helpers.
    """
    model = cls(**model_kwargs)

    if os.path.isdir(model_id):
        print("Loading weights from local directory")
        local_weights = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
        return cls._load_as_safetensor(model, local_weights, map_location, strict)

    # Same download options for both candidate files, only the filename differs
    download_options = dict(
        repo_id=model_id,
        revision=revision,
        cache_dir=cache_dir,
        force_download=force_download,
        proxies=proxies,
        resume_download=resume_download,
        token=token,
        local_files_only=local_files_only,
    )
    try:
        weights_file = hf_hub_download(filename=SAFETENSORS_SINGLE_FILE, **download_options)
        return cls._load_as_safetensor(model, weights_file, map_location, strict)
    except EntryNotFoundError:
        # Fall back on the pickled weights file
        weights_file = hf_hub_download(filename=PYTORCH_WEIGHTS_NAME, **download_options)
        return cls._load_as_pickle(model, weights_file, map_location, strict)
@classmethod
def _load_as_pickle(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
    """Load a state dict saved with `torch.save` into `model`, switch it to eval mode and return it."""
    device = torch.device(map_location)
    # NOTE(review): `torch.load` unpickles the file, and `model_file` may have been downloaded from a
    # remote repository. Consider restricting deserialization (e.g. `weights_only=True` where the
    # installed torch version supports it) -- to be confirmed against the minimum supported version.
    weights = torch.load(model_file, map_location=device)
    model.load_state_dict(weights, strict=strict)  # type: ignore
    model.eval()  # type: ignore
    return model
@classmethod
def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
    """Load a safetensors file into `model` and return it.

    For any `map_location` other than "cpu", a warning is logged and the model is moved with
    `model.to(map_location)` after loading.
    """
    load_model_as_safetensor(model, model_file, strict=strict)  # type: ignore [arg-type]
    if map_location == "cpu":
        return model

    # TODO: remove this once https://github.com/huggingface/safetensors/pull/449 is merged.
    logger.warning(
        "Loading model weights on other devices than 'cpu' is not supported natively."
        " This means that the model is loaded on 'cpu' first and then copied to the device."
        " This leads to a slower loading time."
        " Support for loading directly on other devices is planned to be added in future releases."
        " See https://github.com/huggingface/huggingface_hub/pull/2086 for more details."
    )
    model.to(map_location)  # type: ignore [attr-defined]
    return model
def _load_dataclass(datacls: Type["DataclassInstance"], data: dict) -> "DataclassInstance":
"""Load a dataclass instance from a dictionary.
Fields not expected by the dataclass are ignored.
"""
return datacls(**{k: v for k, v in data.items() if k in datacls.__dataclass_fields__})