Add a default dtype to model configs (#330)
* Add a default `dtype` to model configs

* Add `TransformerConfig` to docs

* Make `ConfigT` typevar bound on just `TransformerConfig`
shadeMe committed Sep 19, 2023
1 parent 01e8902 commit 84246c9
Showing 13 changed files with 82 additions and 47 deletions.
2 changes: 2 additions & 0 deletions curated_transformers/models/__init__.py
@@ -5,6 +5,7 @@
from .config import (
RotaryEmbeddingConfig,
TransformerAttentionLayerConfig,
TransformerConfig,
TransformerEmbeddingLayerConfig,
TransformerFeedForwardLayerConfig,
TransformerLayerConfig,
@@ -52,6 +53,7 @@
"RoBERTaEncoder",
"RotaryEmbeddingConfig",
"TransformerAttentionLayerConfig",
"TransformerConfig",
"TransformerCausalLM",
"TransformerDecoder",
"TransformerEmbeddingLayerConfig",
6 changes: 5 additions & 1 deletion curated_transformers/models/albert/config.py
@@ -1,8 +1,11 @@
from dataclasses import dataclass

import torch

from ...layers.activations import Activation
from ..config import (
TransformerAttentionLayerConfig,
TransformerConfig,
TransformerEmbeddingLayerConfig,
TransformerFeedForwardLayerConfig,
TransformerLayerConfig,
@@ -35,7 +38,7 @@ def __init__(


@dataclass
class ALBERTConfig:
class ALBERTConfig(TransformerConfig):
"""
ALBERT (`Lan et al., 2022`_) model configuration.
@@ -129,4 +132,5 @@ def __init__(
n_layers_per_group=n_layers_per_group,
n_hidden_groups=n_hidden_groups,
)
self.dtype = torch.float32
self.model_max_length = model_max_length
20 changes: 10 additions & 10 deletions curated_transformers/models/auto_model.py
@@ -11,7 +11,7 @@
from .albert import ALBERTEncoder
from .bert import BERTEncoder
from .camembert import CamemBERTEncoder
from .config import ConfigDataclass
from .config import TransformerConfig
from .falcon import FalconCausalLM, FalconDecoder
from .gpt_neox import GPTNeoXCausalLM, GPTNeoXDecoder
from .hf_hub import FromHFHub
@@ -185,7 +185,7 @@ def from_hf_hub_to_cache(
module_cls.from_hf_hub_to_cache(name=name, revision=revision)


class AutoEncoder(AutoModel[EncoderModule[ConfigDataclass]]):
class AutoEncoder(AutoModel[EncoderModule[TransformerConfig]]):
"""
Encoder model loaded from the Hugging Face Model Hub.
"""
@@ -207,7 +207,7 @@ def from_fsspec(
fsspec_args: Optional[Dict[str, Any]] = None,
device: Optional[torch.device] = None,
quantization_config: Optional[BitsAndBytesConfig] = None,
) -> EncoderModule[ConfigDataclass]:
) -> EncoderModule[TransformerConfig]:
encoder = cls._instantiate_model_from_fsspec(
fs, model_path, fsspec_args, device, quantization_config
)
@@ -222,15 +222,15 @@ def from_hf_hub(
revision: str = "main",
device: Optional[torch.device] = None,
quantization_config: Optional[BitsAndBytesConfig] = None,
) -> EncoderModule[ConfigDataclass]:
) -> EncoderModule[TransformerConfig]:
encoder = cls._instantiate_model_from_hf_hub(
name, revision, device, quantization_config
)
assert isinstance(encoder, EncoderModule)
return encoder


class AutoDecoder(AutoModel[DecoderModule[ConfigDataclass, KeyValueCache]]):
class AutoDecoder(AutoModel[DecoderModule[TransformerConfig, KeyValueCache]]):
"""
Decoder module loaded from the Hugging Face Model Hub.
"""
@@ -253,7 +253,7 @@ def from_fsspec(
fsspec_args: Optional[Dict[str, Any]] = None,
device: Optional[torch.device] = None,
quantization_config: Optional[BitsAndBytesConfig] = None,
) -> DecoderModule[ConfigDataclass, KeyValueCache]:
) -> DecoderModule[TransformerConfig, KeyValueCache]:
decoder = cls._instantiate_model_from_fsspec(
fs, model_path, fsspec_args, device, quantization_config
)
@@ -268,15 +268,15 @@ def from_hf_hub(
revision: str = "main",
device: Optional[torch.device] = None,
quantization_config: Optional[BitsAndBytesConfig] = None,
) -> DecoderModule[ConfigDataclass, KeyValueCache]:
) -> DecoderModule[TransformerConfig, KeyValueCache]:
decoder = cls._instantiate_model_from_hf_hub(
name, revision, device, quantization_config
)
assert isinstance(decoder, DecoderModule)
return decoder


class AutoCausalLM(AutoModel[CausalLMModule[ConfigDataclass, KeyValueCache]]):
class AutoCausalLM(AutoModel[CausalLMModule[TransformerConfig, KeyValueCache]]):
"""
Causal LM model loaded from the Hugging Face Model Hub.
"""
@@ -299,7 +299,7 @@ def from_fsspec(
fsspec_args: Optional[Dict[str, Any]] = None,
device: Optional[torch.device] = None,
quantization_config: Optional[BitsAndBytesConfig] = None,
) -> CausalLMModule[ConfigDataclass, KeyValueCache]:
) -> CausalLMModule[TransformerConfig, KeyValueCache]:
causal_lm = cls._instantiate_model_from_fsspec(
fs, model_path, fsspec_args, device, quantization_config
)
@@ -314,7 +314,7 @@ def from_hf_hub(
revision: str = "main",
device: Optional[torch.device] = None,
quantization_config: Optional[BitsAndBytesConfig] = None,
) -> CausalLMModule[ConfigDataclass, KeyValueCache]:
) -> CausalLMModule[TransformerConfig, KeyValueCache]:
causal_lm = cls._instantiate_model_from_hf_hub(
name, revision, device, quantization_config
)
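For context, a brief usage sketch of the updated Auto* classes: the modules they return are now typed against TransformerConfig, so callers can rely on the shared config fields, including the new dtype. This assumes the returned encoder exposes its config as .config (as the hf_hub.py change below relies on); the checkpoint name is only an example.

import torch

from curated_transformers.models.auto_model import AutoEncoder

# Example checkpoint name; any supported encoder checkpoint would do.
encoder = AutoEncoder.from_hf_hub(name="bert-base-uncased", revision="main")

# The config is a TransformerConfig, so the default dtype is always available.
assert isinstance(encoder.config.dtype, torch.dtype)
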
8 changes: 5 additions & 3 deletions curated_transformers/models/bert/config.py
@@ -1,24 +1,25 @@
from dataclasses import dataclass

import torch

from ...layers.activations import Activation
from ..config import (
TransformerAttentionLayerConfig,
TransformerConfig,
TransformerEmbeddingLayerConfig,
TransformerFeedForwardLayerConfig,
TransformerLayerConfig,
)


@dataclass
class BERTConfig:
class BERTConfig(TransformerConfig):
"""
BERT (`Devlin et al., 2018`_) model configuration.
.. _Devlin et al., 2018 : https://arxiv.org/abs/1810.04805
"""

embedding: TransformerEmbeddingLayerConfig
layer: TransformerLayerConfig
model_max_length: int

def __init__(
@@ -98,4 +99,5 @@ def __init__(
layer_norm_eps=layer_norm_eps,
dropout_prob=hidden_dropout_prob,
)
self.dtype = torch.float32
self.model_max_length = model_max_length
31 changes: 22 additions & 9 deletions curated_transformers/models/config.py
@@ -1,15 +1,9 @@
from dataclasses import dataclass
from typing import ClassVar, Optional, Protocol
from typing import Optional

from ..layers.activations import Activation


class ConfigDataclass(Protocol):
"""
Protocol that describes a config data class.
"""
import torch

__dataclass_fields__: ClassVar[dict]
from ..layers.activations import Activation


@dataclass
@@ -136,3 +130,22 @@ class TransformerLayerConfig:
feedforward: TransformerFeedForwardLayerConfig
layer_norm_eps: float
n_hidden_layers: int


@dataclass
class TransformerConfig:
"""
Configuration options for a transformer model.
:param embedding:
Embedding layer config.
:param layer:
Transformer hidden layer config.
:param dtype:
Default data type used by the model's
parameters.
"""

embedding: TransformerEmbeddingLayerConfig
layer: TransformerLayerConfig
dtype: torch.dtype
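
To make the per-model changes in this commit easier to read, here is a minimal sketch of the pattern they all follow: a concrete config inherits the shared embedding, layer, and dtype fields from TransformerConfig and picks its own default dtype in __init__. The class and field names below are illustrative only; a real config (see the ALBERT and BERT hunks above) also builds its embedding and layer sub-configs there.

from dataclasses import dataclass

import torch

from curated_transformers.models.config import TransformerConfig


@dataclass
class ExampleConfig(TransformerConfig):
    # Model-specific option on top of the fields shared via TransformerConfig.
    model_max_length: int

    def __init__(self, *, model_max_length: int = 512):
        # A real config would also construct self.embedding and self.layer here.
        self.dtype = torch.float32
        self.model_max_length = model_max_length
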
8 changes: 5 additions & 3 deletions curated_transformers/models/falcon/config.py
@@ -1,25 +1,26 @@
from dataclasses import dataclass

import torch

from ...layers.activations import Activation
from ..config import (
RotaryEmbeddingConfig,
TransformerAttentionLayerConfig,
TransformerConfig,
TransformerEmbeddingLayerConfig,
TransformerFeedForwardLayerConfig,
TransformerLayerConfig,
)


@dataclass
class FalconConfig:
class FalconConfig(TransformerConfig):
"""
Falcon (`Penedo et al., 2019`_) model configuration.
.. _Penedo et al., 2019: https://arxiv.org/abs/2306.01116
"""

embedding: TransformerEmbeddingLayerConfig
layer: TransformerLayerConfig
new_decoder_architecture: bool

def __init__(
@@ -108,4 +109,5 @@ def __init__(
layer_norm_eps=layer_norm_eps,
n_hidden_layers=n_hidden_layers,
)
self.dtype = torch.bfloat16
self.new_decoder_architecture = new_decoder_architecture
9 changes: 5 additions & 4 deletions curated_transformers/models/gpt_neox/config.py
@@ -1,26 +1,26 @@
from dataclasses import dataclass

import torch

from ...layers.activations import Activation
from ..config import (
RotaryEmbeddingConfig,
TransformerAttentionLayerConfig,
TransformerConfig,
TransformerEmbeddingLayerConfig,
TransformerFeedForwardLayerConfig,
TransformerLayerConfig,
)


@dataclass
class GPTNeoXConfig:
class GPTNeoXConfig(TransformerConfig):
"""
GPT-NeoX (`Black et al., 2022`_) model configuration.
.. _Black et al., 2022: https://arxiv.org/abs/2204.06745
"""

embedding: TransformerEmbeddingLayerConfig
layer: TransformerLayerConfig

def __init__(
self,
*,
@@ -103,3 +103,4 @@ def __init__(
layer_norm_eps=layer_norm_eps,
n_hidden_layers=n_hidden_layers,
)
self.dtype = torch.float16
16 changes: 10 additions & 6 deletions curated_transformers/models/hf_hub.py
@@ -24,6 +24,7 @@
from ..util.fsspec import get_model_config as get_model_config_fsspec
from ..util.hf import get_model_checkpoint_files, get_model_config
from ..util.serde import ModelCheckpointType, ModelFile, load_model_from_checkpoints
from .module import TransformerModule

# Only provided as typing.Self in Python 3.11+.
Self = TypeVar("Self", bound="FromHFHub")
@@ -193,12 +194,15 @@ def _create_and_load_model(
model = cls.from_hf_config(hf_config=config, device=torch.device("meta"))

# Convert the model to the expected dtype.
dtype_str = config.get("torch_dtype")
if dtype_str is not None:
dtype = getattr(torch, dtype_str, None)
if dtype is None or not isinstance(dtype, torch.dtype):
raise ValueError(f"Invalid torch dtype `{dtype_str}`")
model.to(dtype=dtype)
assert isinstance(model, TransformerModule)
dtype: torch.dtype = model.config.dtype
serialized_dtype_str = config.get("torch_dtype")
if serialized_dtype_str is not None:
serialized_dtype = getattr(torch, serialized_dtype_str, None)
if not isinstance(serialized_dtype, torch.dtype):
raise ValueError(f"Invalid torch dtype `{serialized_dtype_str}`")
dtype = serialized_dtype
model.to(dtype=dtype)

# Prepare for quantization.
if quantization_config is not None:
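The dtype resolution introduced above can be summarized as: start from the model config's new default dtype, and let an explicit torch_dtype entry in the checkpoint's Hugging Face config override it. A standalone restatement of that precedence (illustrative helper, not library code):

from typing import Any, Dict, Optional

import torch


def resolve_dtype(default: torch.dtype, hf_config: Dict[str, Any]) -> torch.dtype:
    # Fall back to the model config's default dtype...
    dtype = default
    # ...unless the serialized config declares an explicit torch_dtype.
    serialized: Optional[str] = hf_config.get("torch_dtype")
    if serialized is not None:
        resolved = getattr(torch, serialized, None)
        if not isinstance(resolved, torch.dtype):
            raise ValueError(f"Invalid torch dtype `{serialized}`")
        dtype = resolved
    return dtype


# No torch_dtype in config.json -> the model default wins.
assert resolve_dtype(torch.bfloat16, {}) == torch.bfloat16
# An explicit torch_dtype overrides the default.
assert resolve_dtype(torch.bfloat16, {"torch_dtype": "float16"}) == torch.float16
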
9 changes: 5 additions & 4 deletions curated_transformers/models/llama/config.py
@@ -1,27 +1,27 @@
from dataclasses import dataclass

import torch

from ...layers.activations import Activation
from ..config import (
RotaryEmbeddingConfig,
TransformerAttentionLayerConfig,
TransformerConfig,
TransformerEmbeddingLayerConfig,
TransformerFeedForwardLayerConfig,
TransformerLayerConfig,
)


@dataclass
class LlamaConfig:
class LlamaConfig(TransformerConfig):
"""
Llama (`Touvron et al., 2023 [a]`_, `Touvron et al., 2023 [b]`_) model configuration.
.. _Touvron et al., 2023 [a]: https://arxiv.org/abs/2302.13971
.. _Touvron et al., 2023 [b]: https://arxiv.org/abs/2307.09288
"""

embedding: TransformerEmbeddingLayerConfig
layer: TransformerLayerConfig

def __init__(
self,
*,
@@ -100,3 +100,4 @@ def __init__(
layer_norm_eps=rms_norm_eps,
n_hidden_layers=n_hidden_layers,
)
self.dtype = torch.float16
4 changes: 2 additions & 2 deletions curated_transformers/models/module.py
@@ -5,10 +5,10 @@
from torch.nn import Module

from ..layers.attention import AttentionMask
from .config import ConfigDataclass
from .config import TransformerConfig
from .output import CacheT, CausalLMOutputWithCache, ModelOutput, ModelOutputWithCache

ConfigT = TypeVar("ConfigT", bound=ConfigDataclass)
ConfigT = TypeVar("ConfigT", bound=TransformerConfig)


class TransformerModule(Generic[ConfigT], Module):
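Because ConfigT is now bound to TransformerConfig instead of the removed ConfigDataclass protocol, code that is generic over transformer modules can assume the shared fields (embedding, layer, dtype) exist on the config. A small illustrative sketch, assuming modules expose their config as .config as in the hf_hub.py change above:

from typing import TypeVar

import torch

from curated_transformers.models.config import TransformerConfig
from curated_transformers.models.module import TransformerModule

ConfigT = TypeVar("ConfigT", bound=TransformerConfig)


def default_dtype(module: TransformerModule[ConfigT]) -> torch.dtype:
    # Well-typed for any module now that every config derives from TransformerConfig.
    return module.config.dtype
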
9 changes: 5 additions & 4 deletions curated_transformers/models/mpt/config.py
@@ -1,25 +1,25 @@
from dataclasses import dataclass

import torch

from ...layers.activations import Activation
from ..config import (
TransformerAttentionLayerConfig,
TransformerConfig,
TransformerEmbeddingLayerConfig,
TransformerFeedForwardLayerConfig,
TransformerLayerConfig,
)


@dataclass
class MPTConfig:
class MPTConfig(TransformerConfig):
"""
`MosaicML MPT`_ model configuration.
.. _MosaicML MPT: https://www.mosaicml.com/blog/mpt-7b
"""

embedding: TransformerEmbeddingLayerConfig
layer: TransformerLayerConfig

def __init__(
self,
*,
@@ -94,3 +94,4 @@ def __init__(
layer_norm_eps=layer_norm_eps,
n_hidden_layers=n_hidden_layers,
)
self.dtype = torch.bfloat16
