Add repository abstraction #331

Merged: 6 commits, Sep 26, 2023
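This PR replaces the separate Hugging Face Hub and fsspec loading paths with a single repository abstraction: each Auto* class now implements one from_repo entry point, and from_hf_hub/from_fsspec become thin wrappers around it. Below is a minimal usage sketch based on the calls visible in the diff; the model name and device are placeholders, and the import paths are assumed from the package-relative imports in this commit.

import torch

from curated_transformers.models.auto_model import AutoCausalLM
from curated_transformers.repository.hf_hub import HfHubRepository

# Placeholder model name; any causal LM architecture registered with
# AutoCausalLM works.
causal_lm = AutoCausalLM.from_repo(
    repo=HfHubRepository(name="tiiuae/falcon-7b", revision="main"),
    device=torch.device("cpu"),
)

# The existing entry point still works and now delegates to the same code path.
causal_lm = AutoCausalLM.from_hf_hub(name="tiiuae/falcon-7b", revision="main")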
Changes from 1 commit
187 changes: 61 additions & 126 deletions curated_transformers/models/auto_model.py
@@ -1,13 +1,14 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, Optional, Type, TypeVar
from typing import Dict, Generic, Optional, Type, TypeVar

import torch
from fsspec import AbstractFileSystem

from ..layers.cache import KeyValueCache
from ..quantization.bnb.config import BitsAndBytesConfig
from ..util.fsspec import get_config_model_type as get_config_model_type_fsspec
from ..util.hf import get_config_model_type
from ..repository.fsspec import FsspecArgs, FsspecRepository
from ..repository.hf_hub import HfHubRepository
from ..repository.repository import ModelRepository, Repository
from .albert import ALBERTEncoder
from .bert import BERTEncoder
from .camembert import CamemBERTEncoder
@@ -33,36 +34,12 @@ class AutoModel(ABC, Generic[ModelT]):

_hf_model_type_to_curated: Dict[str, Type[FromHFHub]] = {}

@classmethod
def _resolve_model_cls_fsspec(
cls,
fs: AbstractFileSystem,
model_path: str,
fsspec_args: Optional[Dict[str, Any]] = None,
) -> Type[FromHFHub]:
model_type = get_config_model_type_fsspec(
fs, model_path, fsspec_args=fsspec_args
)
if model_type is None:
raise ValueError(
"The model type is not defined in the model configuration."
)
module_cls = cls._hf_model_type_to_curated.get(model_type)
if module_cls is None:
raise ValueError(
f"Unsupported model type `{model_type}` for {cls.__name__}. "
f"Supported model types: {tuple(cls._hf_model_type_to_curated.keys())}"
)
assert issubclass(module_cls, FromHFHub)
return module_cls

@classmethod
def _resolve_model_cls(
cls,
name: str,
revision: str,
repo: ModelRepository,
) -> Type[FromHFHub]:
model_type = get_config_model_type(name, revision)
model_type = repo.model_type()
module_cls = cls._hf_model_type_to_curated.get(model_type)
if module_cls is None:
raise ValueError(
@@ -73,36 +50,15 @@ def _resolve_model_cls(
return module_cls

@classmethod
def _instantiate_model_from_fsspec(
cls,
fs: AbstractFileSystem,
model_path: str,
fsspec_args: Optional[Dict[str, Any]],
device: Optional[torch.device],
quantization_config: Optional[BitsAndBytesConfig],
) -> FromHFHub:
module_cls = cls._resolve_model_cls_fsspec(fs, model_path)
module = module_cls.from_fsspec(
fs=fs,
model_path=model_path,
fsspec_args=fsspec_args,
device=device,
quantization_config=quantization_config,
)
return module

@classmethod
def _instantiate_model_from_hf_hub(
def _instantiate_model(
cls,
name: str,
revision: str,
repo: Repository,
device: Optional[torch.device],
quantization_config: Optional[BitsAndBytesConfig],
) -> FromHFHub:
module_cls = cls._resolve_model_cls(name, revision)
module = module_cls.from_hf_hub(
name=name,
revision=revision,
module_cls = cls._resolve_model_cls(ModelRepository(repo))
module = module_cls.from_repo(
repo=repo,
device=device,
quantization_config=quantization_config,
)
@@ -114,7 +70,7 @@ def from_fsspec(
*,
fs: AbstractFileSystem,
model_path: str,
fsspec_args: Optional[Dict[str, Any]] = None,
fsspec_args: Optional[FsspecArgs] = None,
device: Optional[torch.device] = None,
quantization_config: Optional[BitsAndBytesConfig] = None,
) -> ModelT:
@@ -135,10 +91,17 @@ def from_fsspec(
:returns:
Module with the parameters loaded.
"""
raise NotImplementedError
return cls.from_repo(
repo=FsspecRepository(
fs,
model_path=model_path,
fsspec_args=fsspec_args,
),
device=device,
quantization_config=quantization_config,
)

@classmethod
@abstractmethod
def from_hf_hub(
cls,
*,
@@ -161,6 +124,34 @@ def from_hf_hub(
:returns:
Loaded model or generator.
"""
return cls.from_repo(
repo=HfHubRepository(name=name, revision=revision),
device=device,
quantization_config=quantization_config,
)

@classmethod
@abstractmethod
def from_repo(
cls,
*,
repo: Repository,
device: Optional[torch.device] = None,
quantization_config: Optional[BitsAndBytesConfig] = None,
) -> ModelT:
"""
Construct and load a model or a generator from a repository.

:param repo:
The repository to load from.
:param device:
Device on which to initialize the model.
:param quantization_config:
Configuration for loading quantized weights.
:returns:
Loaded model or generator.
"""

raise NotImplementedError

@classmethod
@@ -181,8 +172,9 @@ def from_hf_hub_to_cache(
:param revision:
Model revision.
"""
module_cls = cls._resolve_model_cls(name, revision)
module_cls.from_hf_hub_to_cache(name=name, revision=revision)
repo = ModelRepository(HfHubRepository(name=name, revision=revision))
repo.model_config()
repo.model_checkpoints()


class AutoEncoder(AutoModel[EncoderModule[TransformerConfig]]):
@@ -199,33 +191,14 @@ class AutoEncoder(AutoModel[EncoderModule[TransformerConfig]]):
}

@classmethod
def from_fsspec(
def from_repo(
cls,
*,
fs: AbstractFileSystem,
model_path: str,
fsspec_args: Optional[Dict[str, Any]] = None,
repo: Repository,
device: Optional[torch.device] = None,
quantization_config: Optional[BitsAndBytesConfig] = None,
) -> EncoderModule[TransformerConfig]:
encoder = cls._instantiate_model_from_fsspec(
fs, model_path, fsspec_args, device, quantization_config
)
assert isinstance(encoder, EncoderModule)
return encoder

@classmethod
def from_hf_hub(
cls,
*,
name: str,
revision: str = "main",
device: Optional[torch.device] = None,
quantization_config: Optional[BitsAndBytesConfig] = None,
) -> EncoderModule[TransformerConfig]:
encoder = cls._instantiate_model_from_hf_hub(
name, revision, device, quantization_config
)
encoder = cls._instantiate_model(repo, device, quantization_config)
assert isinstance(encoder, EncoderModule)
return encoder

@@ -245,33 +218,14 @@ class AutoDecoder(AutoModel[DecoderModule[TransformerConfig, KeyValueCache]]):
}

@classmethod
def from_fsspec(
cls,
*,
fs: AbstractFileSystem,
model_path: str,
fsspec_args: Optional[Dict[str, Any]] = None,
device: Optional[torch.device] = None,
quantization_config: Optional[BitsAndBytesConfig] = None,
) -> DecoderModule[TransformerConfig, KeyValueCache]:
decoder = cls._instantiate_model_from_fsspec(
fs, model_path, fsspec_args, device, quantization_config
)
assert isinstance(decoder, DecoderModule)
return decoder

@classmethod
def from_hf_hub(
def from_repo(
cls,
*,
name: str,
revision: str = "main",
repo: Repository,
device: Optional[torch.device] = None,
quantization_config: Optional[BitsAndBytesConfig] = None,
) -> DecoderModule[TransformerConfig, KeyValueCache]:
decoder = cls._instantiate_model_from_hf_hub(
name, revision, device, quantization_config
)
decoder = cls._instantiate_model(repo, device, quantization_config)
assert isinstance(decoder, DecoderModule)
return decoder

@@ -291,32 +245,13 @@ class AutoCausalLM(AutoModel[CausalLMModule[TransformerConfig, KeyValueCache]]):
}

@classmethod
def from_fsspec(
def from_repo(
cls,
*,
fs: AbstractFileSystem,
model_path: str,
fsspec_args: Optional[Dict[str, Any]] = None,
repo: Repository,
device: Optional[torch.device] = None,
quantization_config: Optional[BitsAndBytesConfig] = None,
) -> CausalLMModule[TransformerConfig, KeyValueCache]:
causal_lm = cls._instantiate_model_from_fsspec(
fs, model_path, fsspec_args, device, quantization_config
)
assert isinstance(causal_lm, CausalLMModule)
return causal_lm

@classmethod
def from_hf_hub(
cls,
*,
name: str,
revision: str = "main",
device: Optional[torch.device] = None,
quantization_config: Optional[BitsAndBytesConfig] = None,
) -> CausalLMModule[TransformerConfig, KeyValueCache]:
causal_lm = cls._instantiate_model_from_hf_hub(
name, revision, device, quantization_config
)
causal_lm = cls._instantiate_model(repo, device, quantization_config)
assert isinstance(causal_lm, CausalLMModule)
return causal_lm
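The fsspec path now funnels through the same abstraction: from_fsspec builds an FsspecRepository and calls from_repo. A short sketch follows, with the local path as a placeholder and parameter names taken from the calls in this commit.

import torch
from fsspec.implementations.local import LocalFileSystem

from curated_transformers.models.auto_model import AutoEncoder
from curated_transformers.repository.fsspec import FsspecRepository

fs = LocalFileSystem()

# Existing entry point, now a thin wrapper around from_repo.
encoder = AutoEncoder.from_fsspec(
    fs=fs,
    model_path="/models/bert-base-cased",  # placeholder path to an HF-style checkpoint
    device=torch.device("cpu"),
)

# Equivalent call through the repository abstraction directly.
encoder = AutoEncoder.from_repo(
    repo=FsspecRepository(fs, model_path="/models/bert-base-cased"),
    device=torch.device("cpu"),
)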
53 changes: 29 additions & 24 deletions curated_transformers/models/hf_hub.py
@@ -18,12 +18,10 @@

from ..quantization import prepare_module_for_quantization
from ..quantization.bnb.config import BitsAndBytesConfig
from ..util.fsspec import (
get_model_checkpoint_files as get_model_checkpoint_files_fsspec,
)
from ..util.fsspec import get_model_config as get_model_config_fsspec
from ..util.hf import get_model_checkpoint_files, get_model_config
from ..util.serde import ModelCheckpointType, ModelFile, load_model_from_checkpoints
from ..repository.fsspec import FsspecArgs, FsspecRepository
from ..repository.hf_hub import HfHubRepository
from ..repository.repository import ModelRepository, Repository
from ..util.serde import load_model_from_checkpoints
from .module import TransformerModule

# Only provided as typing.Self in Python 3.11+.
@@ -94,16 +92,17 @@ def from_hf_hub_to_cache(
:param revision:
Model revision.
"""
_ = get_model_config(name, revision)
_ = get_model_checkpoint_files(name, revision)
repo = ModelRepository(HfHubRepository(name=name, revision=revision))
repo.model_config()
repo.model_checkpoints()

@classmethod
def from_fsspec(
cls: Type[Self],
*,
fs: AbstractFileSystem,
model_path: str,
fsspec_args: Optional[Dict[str, Any]] = None,
fsspec_args: Optional[FsspecArgs] = None,
device: Optional[torch.device] = None,
quantization_config: Optional[BitsAndBytesConfig] = None,
) -> Self:
@@ -124,13 +123,8 @@ def from_fsspec(
:returns:
Module with the parameters loaded.
"""
return cls._create_and_load_model(
get_config=lambda: get_model_config_fsspec(
fs, model_path, fsspec_args=fsspec_args
),
get_checkpoint_files=lambda: get_model_checkpoint_files_fsspec(
fs, model_path, fsspec_args=fsspec_args
),
return cls.from_repo(
repo=FsspecRepository(fs, model_path, fsspec_args),
device=device,
quantization_config=quantization_config,
)
@@ -158,9 +152,8 @@ def from_hf_hub(
:returns:
Module with the parameters loaded.
"""
return cls._create_and_load_model(
get_config=lambda: get_model_config(name, revision),
get_checkpoint_files=lambda: get_model_checkpoint_files(name, revision),
return cls.from_repo(
repo=HfHubRepository(name=name, revision=revision),
device=device,
quantization_config=quantization_config,
)
@@ -182,15 +175,27 @@ def to(
...

@classmethod
def _create_and_load_model(
def from_repo(
cls: Type[Self],
*,
get_config: Callable[[], Dict[Any, str]],
get_checkpoint_files: Callable[[], Tuple[List[ModelFile], ModelCheckpointType]],
repo: Repository,
device: Optional[torch.device] = None,
quantization_config: Optional[BitsAndBytesConfig] = None,
) -> Self:
config = get_config()
"""
Construct and load a model from a repository.

:param repo:
The repository to load from.
:param device:
Device on which to initialize the model.
:param quantization_config:
Configuration for loading quantized weights.
:returns:
Loaded model.
"""
model_repo = ModelRepository(repo)
config = model_repo.model_config()
model = cls.from_hf_config(hf_config=config, device=torch.device("meta"))

# Convert the model to the expected dtype.
@@ -211,7 +216,7 @@ def _create_and_load_model(
tensor2param = None

# Download model and convert HF parameter names to ours.
checkpoint_filenames, checkpoint_type = get_checkpoint_files()
checkpoint_filenames, checkpoint_type = model_repo.model_checkpoints()
load_model_from_checkpoints(
model, # type:ignore
filepaths=checkpoint_filenames,
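For reference, the ModelRepository wrapper used throughout this diff exposes the accessors that replace the old get_model_config/get_model_checkpoint_files helpers. A rough sketch of the cache-warming path in from_hf_hub_to_cache, using only calls that appear above; the model name is a placeholder.

from curated_transformers.repository.hf_hub import HfHubRepository
from curated_transformers.repository.repository import ModelRepository

# Mirrors from_hf_hub_to_cache: fetch the config and checkpoint files so that a
# later from_hf_hub/from_repo call can load the model from the local cache.
repo = ModelRepository(HfHubRepository(name="bert-base-cased", revision="main"))
config = repo.model_config()                           # Hugging Face model config
filenames, checkpoint_type = repo.model_checkpoints()  # checkpoint files and their type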