Skip to content

Commit

Permalink
Set default revision for tiiuae/* repos to models in the in-library format
Browse files (browse the repository at this point in the history)
  • Loading branch information
borzunov committed Sep 3, 2023
1 parent 97dd376 commit 72033d1
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/run-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ jobs:
- { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.11' }
fail-fast: false
runs-on: ${{ matrix.os }}-latest
timeout-minutes: 15
timeout-minutes: 20
steps:
- name: Increase swap space
if: ${{ matrix.os == 'ubuntu' }}
Expand Down
29 changes: 28 additions & 1 deletion src/petals/models/falcon/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,34 @@
logger = get_logger(__name__)


class DistributedFalconConfig(FalconConfig, ClientConfig, PTuneConfig, LMHeadConfig):
class DefaultRevisionMixin:
    """
    Mixin that pins a default Hub revision for the TII Falcon repositories.

    Petals supports Falcon only in the new in-library format (transformers.FalconModel).
    The TII models were converted to that format and later reverted because of compatibility
    issues. HF staff promised to convert them again eventually, see
    https://huggingface.co/tiiuae/falcon-40b/discussions/90#64b4d23bf44fd957492f7602
    Until that happens, the default `main` revision of the TII repos is replaced with
    the commit that points to the model stored in the in-library format.
    """

    # Repo name -> commit hash substituted when the caller does not request a revision.
    DEFAULT_REVISIONS = {
        "tiiuae/falcon-40b": "f1ba7d328c06aa6fbb4a8afd3c756f46d7e6b232",
        "tiiuae/falcon-40b-instruct": "7475ff8cfc36ed9a962b658ae3c33391566a85a5",
        "tiiuae/falcon-7b": "4e2d06f0a7c6370ebabbc30c6f59377ae8f73d76",
        "tiiuae/falcon-7b-instruct": "f8dac3fff96d5debd43edf56fb4e1abcfffbef28",
    }

    @classmethod
    def from_pretrained(
        cls, model_name_or_path: Union[str, os.PathLike, None], *args, revision: Optional[str] = None, **kwargs
    ):
        """
        Delegate to the next `from_pretrained` in the MRO, filling in a pinned revision.

        An explicitly passed `revision` is always forwarded untouched. The pinned commit
        is used only when `revision` is None and `model_name_or_path` is a key of
        `DEFAULT_REVISIONS`; the substitution is reported through the module logger.
        """
        # Nothing to override: the caller chose a revision or the repo has no pinned commit.
        if revision is not None or model_name_or_path not in cls.DEFAULT_REVISIONS:
            return super().from_pretrained(model_name_or_path, *args, revision=revision, **kwargs)

        pinned_commit = cls.DEFAULT_REVISIONS[model_name_or_path]
        logger.info(f"Using revision {pinned_commit} for {model_name_or_path}")
        return super().from_pretrained(model_name_or_path, *args, revision=pinned_commit, **kwargs)


class DistributedFalconConfig(DefaultRevisionMixin, FalconConfig, ClientConfig, PTuneConfig, LMHeadConfig):
block_class = WrappedFalconBlock
attn_class = FalconAttention
block_prefix = "transformer.h"
Expand Down
11 changes: 6 additions & 5 deletions src/petals/models/falcon/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@
from petals.client.ptune import PTuneMixin
from petals.client.remote_generation import RemoteGenerationMixin, RemotePastKeyValues
from petals.client.remote_sequential import RemoteSequential
from petals.models.falcon.config import DistributedFalconConfig
from petals.models.falcon.config import DefaultRevisionMixin, DistributedFalconConfig

logger = get_logger(__name__)


class DistributedFalconModel(FromPretrainedMixin, PTuneMixin, FalconModel):
class DistributedFalconModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMixin, FalconModel):
"""FalconModel, but all transformer layers are hosted by the swarm"""

_keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing
Expand Down Expand Up @@ -111,9 +111,8 @@ def word_embeddings_layernorm(self) -> nn.Module: # For compatibility with Remo
return nn.Identity()


class DistributedFalconForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, FalconForCausalLM):
class DistributedFalconForCausalLM(DefaultRevisionMixin, FromPretrainedMixin, RemoteGenerationMixin, FalconForCausalLM):
_keys_to_ignore_on_load_missing = DistributedFalconModel._keys_to_ignore_on_load_missing
# _keys_to_ignore_on_load_missing += [r"^lm_head\."] # Missing since they are shared with input embeddings
_keys_to_ignore_on_load_unexpected = DistributedFalconModel._keys_to_ignore_on_load_unexpected

config_class = DistributedFalconConfig
Expand All @@ -130,7 +129,9 @@ def get_output_embeddings(self):
return self.lm_head


class DistributedFalconForSequenceClassification(FromPretrainedMixin, FalconForSequenceClassification):
class DistributedFalconForSequenceClassification(
DefaultRevisionMixin, FromPretrainedMixin, FalconForSequenceClassification
):
_keys_to_ignore_on_load_missing = DistributedFalconModel._keys_to_ignore_on_load_missing
_keys_to_ignore_on_load_unexpected = DistributedFalconModel._keys_to_ignore_on_load_unexpected

Expand Down

0 comments on commit 72033d1

Please sign in to comment.