Add a default dtype to model configs #330

Merged
merged 3 commits on Sep 19, 2023
Changes from 2 commits
2 changes: 2 additions & 0 deletions curated_transformers/models/__init__.py
@@ -5,6 +5,7 @@
from .config import (
RotaryEmbeddingConfig,
TransformerAttentionLayerConfig,
TransformerConfig,
TransformerEmbeddingLayerConfig,
TransformerFeedForwardLayerConfig,
TransformerLayerConfig,
@@ -52,6 +53,7 @@
"RoBERTaEncoder",
"RotaryEmbeddingConfig",
"TransformerAttentionLayerConfig",
"TransformerConfig",
"TransformerCausalLM",
"TransformerDecoder",
"TransformerEmbeddingLayerConfig",
6 changes: 5 additions & 1 deletion curated_transformers/models/albert/config.py
@@ -1,8 +1,11 @@
from dataclasses import dataclass

import torch

from ...layers.activations import Activation
from ..config import (
TransformerAttentionLayerConfig,
TransformerConfig,
TransformerEmbeddingLayerConfig,
TransformerFeedForwardLayerConfig,
TransformerLayerConfig,
@@ -35,7 +38,7 @@ def __init__(


@dataclass
class ALBERTConfig:
class ALBERTConfig(TransformerConfig):
"""
ALBERT (`Lan et al., 2022`_) model configuration.

@@ -129,4 +132,5 @@ def __init__(
n_layers_per_group=n_layers_per_group,
n_hidden_groups=n_hidden_groups,
)
self.dtype = torch.float32
self.model_max_length = model_max_length
8 changes: 5 additions & 3 deletions curated_transformers/models/bert/config.py
@@ -1,24 +1,25 @@
from dataclasses import dataclass

import torch

from ...layers.activations import Activation
from ..config import (
TransformerAttentionLayerConfig,
TransformerConfig,
TransformerEmbeddingLayerConfig,
TransformerFeedForwardLayerConfig,
TransformerLayerConfig,
)


@dataclass
class BERTConfig:
class BERTConfig(TransformerConfig):
"""
BERT (`Devlin et al., 2018`_) model configuration.

.. _Devlin et al., 2018 : https://arxiv.org/abs/1810.04805
"""

embedding: TransformerEmbeddingLayerConfig
layer: TransformerLayerConfig
model_max_length: int

def __init__(
@@ -98,4 +99,5 @@ def __init__(
layer_norm_eps=layer_norm_eps,
dropout_prob=hidden_dropout_prob,
)
self.dtype = torch.float32
self.model_max_length = model_max_length
21 changes: 21 additions & 0 deletions curated_transformers/models/config.py
@@ -1,6 +1,8 @@
from dataclasses import dataclass
from typing import ClassVar, Optional, Protocol

import torch

from ..layers.activations import Activation


@@ -136,3 +138,22 @@ class TransformerLayerConfig:
feedforward: TransformerFeedForwardLayerConfig
layer_norm_eps: float
n_hidden_layers: int


@dataclass
class TransformerConfig:
"""
Configuration options for a transformer model.

:param embedding:
Embedding layer config.
:param layer:
Transformer hidden layer config.
:param dtype:
Default data type used by the model's
parameters.
"""

embedding: TransformerEmbeddingLayerConfig
layer: TransformerLayerConfig
dtype: torch.dtype
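
As a usage sketch (not part of this diff): once every concrete config inherits from TransformerConfig, callers can read the default dtype directly from the config. The snippet below assumes BERTConfig and FalconConfig are importable from curated_transformers.models and that all of their constructor arguments have defaults; the dtype values match the defaults added in this PR.

import torch

from curated_transformers.models import BERTConfig, FalconConfig

# Every config now carries a default parameter dtype.
assert BERTConfig().dtype == torch.float32
assert FalconConfig().dtype == torch.bfloat16
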
8 changes: 5 additions & 3 deletions curated_transformers/models/falcon/config.py
@@ -1,25 +1,26 @@
from dataclasses import dataclass

import torch

from ...layers.activations import Activation
from ..config import (
RotaryEmbeddingConfig,
TransformerAttentionLayerConfig,
TransformerConfig,
TransformerEmbeddingLayerConfig,
TransformerFeedForwardLayerConfig,
TransformerLayerConfig,
)


@dataclass
class FalconConfig:
class FalconConfig(TransformerConfig):
"""
Falcon (`Penedo et al., 2019`_) model configuration.

.. _Penedo et al., 2019: https://arxiv.org/abs/2306.01116
"""

embedding: TransformerEmbeddingLayerConfig
layer: TransformerLayerConfig
new_decoder_architecture: bool

def __init__(
@@ -108,4 +109,5 @@ def __init__(
layer_norm_eps=layer_norm_eps,
n_hidden_layers=n_hidden_layers,
)
self.dtype = torch.bfloat16
self.new_decoder_architecture = new_decoder_architecture
9 changes: 5 additions & 4 deletions curated_transformers/models/gpt_neox/config.py
@@ -1,26 +1,26 @@
from dataclasses import dataclass

import torch

from ...layers.activations import Activation
from ..config import (
RotaryEmbeddingConfig,
TransformerAttentionLayerConfig,
TransformerConfig,
TransformerEmbeddingLayerConfig,
TransformerFeedForwardLayerConfig,
TransformerLayerConfig,
)


@dataclass
class GPTNeoXConfig:
class GPTNeoXConfig(TransformerConfig):
"""
GPT-NeoX (`Black et al., 2022`_) model configuration.

.. _Black et al., 2022: https://arxiv.org/abs/2204.06745
"""

embedding: TransformerEmbeddingLayerConfig
layer: TransformerLayerConfig

def __init__(
self,
*,
@@ -103,3 +103,4 @@ def __init__(
layer_norm_eps=layer_norm_eps,
n_hidden_layers=n_hidden_layers,
)
self.dtype = torch.float16
16 changes: 10 additions & 6 deletions curated_transformers/models/hf_hub.py
@@ -24,6 +24,7 @@
from ..util.fsspec import get_model_config as get_model_config_fsspec
from ..util.hf import get_model_checkpoint_files, get_model_config
from ..util.serde import ModelCheckpointType, ModelFile, load_model_from_checkpoints
from .module import TransformerModule

# Only provided as typing.Self in Python 3.11+.
Self = TypeVar("Self", bound="FromHFHub")
@@ -193,12 +194,15 @@ def _create_and_load_model(
model = cls.from_hf_config(hf_config=config, device=torch.device("meta"))

# Convert the model to the expected dtype.
dtype_str = config.get("torch_dtype")
if dtype_str is not None:
dtype = getattr(torch, dtype_str, None)
if dtype is None or not isinstance(dtype, torch.dtype):
raise ValueError(f"Invalid torch dtype `{dtype_str}`")
model.to(dtype=dtype)
assert isinstance(model, TransformerModule)
dtype: torch.dtype = model.config.dtype
serialized_dtype_str = config.get("torch_dtype")
if serialized_dtype_str is not None:
serialized_dtype = getattr(torch, serialized_dtype_str, None)
if not isinstance(serialized_dtype, torch.dtype):
raise ValueError(f"Invalid torch dtype `{serialized_dtype_str}`")
dtype = serialized_dtype
model.to(dtype=dtype)

# Prepare for quantization.
if quantization_config is not None:
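
To summarize the precedence above: the model starts from its config's default dtype and is only cast to a different dtype when the Hugging Face config serializes a torch_dtype entry. A stand-alone sketch of that resolution (the helper name below is illustrative, not part of the library):

import torch

def resolve_dtype(default_dtype: torch.dtype, hf_config: dict) -> torch.dtype:
    # Prefer the dtype serialized in the HF config; otherwise keep the
    # model config's default.
    dtype_str = hf_config.get("torch_dtype")
    if dtype_str is None:
        return default_dtype
    dtype = getattr(torch, dtype_str, None)
    if not isinstance(dtype, torch.dtype):
        raise ValueError(f"Invalid torch dtype `{dtype_str}`")
    return dtype

assert resolve_dtype(torch.bfloat16, {}) is torch.bfloat16
assert resolve_dtype(torch.bfloat16, {"torch_dtype": "float16"}) is torch.float16
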
9 changes: 5 additions & 4 deletions curated_transformers/models/llama/config.py
@@ -1,27 +1,27 @@
from dataclasses import dataclass

import torch

from ...layers.activations import Activation
from ..config import (
RotaryEmbeddingConfig,
TransformerAttentionLayerConfig,
TransformerConfig,
TransformerEmbeddingLayerConfig,
TransformerFeedForwardLayerConfig,
TransformerLayerConfig,
)


@dataclass
class LlamaConfig:
class LlamaConfig(TransformerConfig):
"""
Llama (`Touvron et al., 2023 [a]`_, `Touvron et al., 2023 [b]`_) model configuration.

.. _Touvron et al., 2023 [a]: https://arxiv.org/abs/2302.13971
.. _Touvron et al., 2023 [b]: https://arxiv.org/abs/2307.09288
"""

embedding: TransformerEmbeddingLayerConfig
layer: TransformerLayerConfig

def __init__(
self,
*,
@@ -100,3 +100,4 @@ def __init__(
layer_norm_eps=rms_norm_eps,
n_hidden_layers=n_hidden_layers,
)
self.dtype = torch.float16
6 changes: 3 additions & 3 deletions curated_transformers/models/module.py
@@ -1,14 +1,14 @@
from abc import abstractmethod
from typing import Generic, List, Optional, TypeVar
from typing import Generic, List, Optional, TypeVar, Union

from torch import Tensor
from torch.nn import Module

from ..layers.attention import AttentionMask
from .config import ConfigDataclass
from .config import ConfigDataclass, TransformerConfig
from .output import CacheT, CausalLMOutputWithCache, ModelOutput, ModelOutputWithCache

ConfigT = TypeVar("ConfigT", bound=ConfigDataclass)
ConfigT = TypeVar("ConfigT", bound=Union[ConfigDataclass, TransformerConfig])


class TransformerModule(Generic[ConfigT], Module):
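
For context on the widened bound: a TypeVar bound to a Union accepts configs from either family, which is what lets a TransformerModule's config carry the new dtype field while older ConfigDataclass-style configs keep working. A minimal, self-contained sketch of the pattern with illustrative stand-in names (not the library's classes):

from dataclasses import dataclass
from typing import Generic, TypeVar, Union

import torch

@dataclass
class DtypeConfig:        # stand-in for TransformerConfig
    dtype: torch.dtype

@dataclass
class LegacyConfig:       # stand-in for a ConfigDataclass-style config
    name: str

ConfigT = TypeVar("ConfigT", bound=Union[LegacyConfig, DtypeConfig])

class ConfigurableModule(Generic[ConfigT]):
    def __init__(self, config: ConfigT):
        self.config = config

m = ConfigurableModule(DtypeConfig(dtype=torch.float16))
print(m.config.dtype)  # torch.float16
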
9 changes: 5 additions & 4 deletions curated_transformers/models/mpt/config.py
@@ -1,25 +1,25 @@
from dataclasses import dataclass

import torch

from ...layers.activations import Activation
from ..config import (
TransformerAttentionLayerConfig,
TransformerConfig,
TransformerEmbeddingLayerConfig,
TransformerFeedForwardLayerConfig,
TransformerLayerConfig,
)


@dataclass
class MPTConfig:
class MPTConfig(TransformerConfig):
"""
`MosaicML MPT`_ model configuration.

.. _MosaicML MPT: https://www.mosaicml.com/blog/mpt-7b
"""

embedding: TransformerEmbeddingLayerConfig
layer: TransformerLayerConfig

def __init__(
self,
*,
@@ -94,3 +94,4 @@ def __init__(
layer_norm_eps=layer_norm_eps,
n_hidden_layers=n_hidden_layers,
)
self.dtype = torch.bfloat16
4 changes: 3 additions & 1 deletion curated_transformers/models/roberta/config.py
@@ -1,5 +1,7 @@
from dataclasses import dataclass

import torch

from ..bert import BERTConfig


@@ -61,5 +63,5 @@ def __init__(
n_pieces=n_pieces,
**kwargs
)

self.dtype = torch.float32
self.padding_id = padding_id
3 changes: 3 additions & 0 deletions docs/source/building-blocks.rst
@@ -151,3 +151,6 @@ These dataclasses encapsulate the configurable parameters of the Transformer mod

.. autoclass:: curated_transformers.models.TransformerLayerConfig
:members:

.. autoclass:: curated_transformers.models.TransformerConfig
:members: