
Commit

Set default revision for tiiuae/* repos to models in the in-library format

borzunov committed Sep 3, 2023
1 parent 97dd376 commit cac654a
Showing 4 changed files with 43 additions and 11 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/run-tests.yaml
@@ -20,7 +20,7 @@ jobs:
           - { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.11' }
       fail-fast: false
     runs-on: ${{ matrix.os }}-latest
-    timeout-minutes: 15
+    timeout-minutes: 20
     steps:
       - name: Increase swap space
         if: ${{ matrix.os == 'ubuntu' }}
3 changes: 2 additions & 1 deletion src/petals/models/falcon/config.py
@@ -9,11 +9,12 @@
 from petals.client.lm_head import LMHeadConfig
 from petals.client.ptune import PTuneConfig
 from petals.models.falcon.block import WrappedFalconBlock
+from petals.utils.auto_config import DefaultRevisionMixin
 
 logger = get_logger(__name__)
 
 
-class DistributedFalconConfig(FalconConfig, ClientConfig, PTuneConfig, LMHeadConfig):
+class DistributedFalconConfig(DefaultRevisionMixin, FalconConfig, ClientConfig, PTuneConfig, LMHeadConfig):
     block_class = WrappedFalconBlock
     attn_class = FalconAttention
     block_prefix = "transformer.h"
10 changes: 6 additions & 4 deletions src/petals/models/falcon/model.py
@@ -18,11 +18,12 @@
 from petals.client.remote_generation import RemoteGenerationMixin, RemotePastKeyValues
 from petals.client.remote_sequential import RemoteSequential
 from petals.models.falcon.config import DistributedFalconConfig
+from petals.utils.auto_config import DefaultRevisionMixin
 
 logger = get_logger(__name__)
 
 
-class DistributedFalconModel(FromPretrainedMixin, PTuneMixin, FalconModel):
+class DistributedFalconModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMixin, FalconModel):
     """FalconModel, but all transformer layers are hosted by the swarm"""
 
     _keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing
@@ -111,9 +112,8 @@ def word_embeddings_layernorm(self) -> nn.Module: # For compatibility with Remo
         return nn.Identity()
 
 
-class DistributedFalconForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, FalconForCausalLM):
+class DistributedFalconForCausalLM(DefaultRevisionMixin, FromPretrainedMixin, RemoteGenerationMixin, FalconForCausalLM):
     _keys_to_ignore_on_load_missing = DistributedFalconModel._keys_to_ignore_on_load_missing
-    # _keys_to_ignore_on_load_missing += [r"^lm_head\."]  # Missing since they are shared with input embeddings
     _keys_to_ignore_on_load_unexpected = DistributedFalconModel._keys_to_ignore_on_load_unexpected
 
     config_class = DistributedFalconConfig
@@ -130,7 +130,9 @@ def get_output_embeddings(self):
         return self.lm_head
 
 
-class DistributedFalconForSequenceClassification(FromPretrainedMixin, FalconForSequenceClassification):
+class DistributedFalconForSequenceClassification(
+    DefaultRevisionMixin, FromPretrainedMixin, FalconForSequenceClassification
+):
     _keys_to_ignore_on_load_missing = DistributedFalconModel._keys_to_ignore_on_load_missing
     _keys_to_ignore_on_load_unexpected = DistributedFalconModel._keys_to_ignore_on_load_unexpected
 
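
In every class above, DefaultRevisionMixin is listed first so that its `from_pretrained` runs before the loading logic inherited from `FromPretrainedMixin` and the `transformers` base classes, fills in the pinned revision, and then delegates via `super()`. The sketch below is a minimal, self-contained illustration of that cooperative-mixin pattern; it is not Petals code, and `BaseLoader`, `PinnedRevisionMixin`, and the repo names are made up:

# Illustration only: a mixin placed first in the MRO intercepts from_pretrained,
# substitutes a default revision, and delegates the rest of the work via super().
from typing import Optional


class BaseLoader:
    @classmethod
    def from_pretrained(cls, name: str, *, revision: Optional[str] = None) -> str:
        # Stand-in for the real loading logic; just reports what would be fetched.
        return f"{name}@{revision or 'main'}"


class PinnedRevisionMixin:
    PINS = {"example/repo": "deadbeef"}  # hypothetical repo -> commit mapping

    @classmethod
    def from_pretrained(cls, name: str, *, revision: Optional[str] = None) -> str:
        if revision is None and name in cls.PINS:
            revision = cls.PINS[name]
        return super().from_pretrained(name, revision=revision)


class Loader(PinnedRevisionMixin, BaseLoader):  # mixin first, as in the diff above
    pass


print(Loader.from_pretrained("example/repo"))                   # example/repo@deadbeef
print(Loader.from_pretrained("example/repo", revision="main"))  # example/repo@main
print(Loader.from_pretrained("other/repo"))                     # other/repo@main
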
39 changes: 34 additions & 5 deletions src/petals/utils/auto_config.py
@@ -1,12 +1,14 @@
 import os
 import re
 from dataclasses import dataclass
 from typing import Optional, Type, Union
 
+from hivemind import get_logger
 from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
 
 from petals.utils.hf_auth import always_needs_auth
 
+logger = get_logger(__name__)
 
 
 @dataclass
 class _ModelClasses:
@@ -49,17 +51,44 @@ def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike, None], *arg
         return proper_cls.from_pretrained(model_name_or_path, *args, **kwargs)
 
 
-class AutoDistributedConfig(_AutoDistributedBase):
+class DefaultRevisionMixin:
+    """
+    Petals only supports Falcon loaded in the new in-library format (transformers.FalconModel).
+    TII models were recently converted to this format but then reverted due to compatibility issues.
+    We chose to support only the new format since HF staff promised to eventually convert these models
+    to the new format again, see https://huggingface.co/tiiuae/falcon-40b/discussions/90#64b4d23bf44fd957492f7602
+    Until that happens, we override the default `main` revision for the TII repos with the commit
+    pointing to the model in the in-library format.
+    """
+
+    DEFAULT_REVISIONS = {
+        "tiiuae/falcon-40b": "f1ba7d328c06aa6fbb4a8afd3c756f46d7e6b232",
+        "tiiuae/falcon-40b-instruct": "7475ff8cfc36ed9a962b658ae3c33391566a85a5",
+        "tiiuae/falcon-7b": "4e2d06f0a7c6370ebabbc30c6f59377ae8f73d76",
+        "tiiuae/falcon-7b-instruct": "f8dac3fff96d5debd43edf56fb4e1abcfffbef28",
+    }
+
+    @classmethod
+    def from_pretrained(
+        cls, model_name_or_path: Union[str, os.PathLike, None], *args, revision: Optional[str] = None, **kwargs
+    ):
+        if revision is None and model_name_or_path in cls.DEFAULT_REVISIONS:
+            revision = cls.DEFAULT_REVISIONS[model_name_or_path]
+            logger.info(f"Loading {model_name_or_path}, revision {revision}")
+        return super().from_pretrained(model_name_or_path, *args, revision=revision, **kwargs)
+
+
+class AutoDistributedConfig(DefaultRevisionMixin, _AutoDistributedBase):
     _mapping_field = "config"
 
 
-class AutoDistributedModel(_AutoDistributedBase):
+class AutoDistributedModel(DefaultRevisionMixin, _AutoDistributedBase):
     _mapping_field = "model"
 
 
-class AutoDistributedModelForCausalLM(_AutoDistributedBase):
+class AutoDistributedModelForCausalLM(DefaultRevisionMixin, _AutoDistributedBase):
     _mapping_field = "model_for_causal_lm"
 
 
-class AutoDistributedModelForSequenceClassification(_AutoDistributedBase):
+class AutoDistributedModelForSequenceClassification(DefaultRevisionMixin, _AutoDistributedBase):
     _mapping_field = "model_for_sequence_classification"
