Add Falcon support (bigscience-workshop#499)
This PR adds:

- Support for models based on `transformers.FalconModel` (the in-library format for Falcon). Tested on Falcon-40B (a minimal client-side usage sketch follows the lists below).
- CI tests for Falcon-RW-1B.
- `--throughput dry_run` option to evaluate throughput and exit right away (implemented by @mryab).

Limitations:

- Backward-pass support is currently broken; it will be fixed in bigscience-workshop#500.
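
A minimal client-side usage sketch (an illustration for this writeup, not part of the commit). It assumes the standard Petals auto class `AutoDistributedModelForCausalLM` resolves the Falcon classes registered in this PR and uses the `petals-team/falcon-rw-1b` repo from the CI matrix:

```python
# Hedged sketch: inference with a distributed Falcon model through the Petals client.
import torch
from transformers import AutoTokenizer
from petals import AutoDistributedModelForCausalLM

model_name = "petals-team/falcon-rw-1b"  # same repo as in the CI matrix below
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoDistributedModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32)

inputs = tokenizer("A falcon can fly", return_tensors="pt")["input_ids"]
outputs = model.generate(inputs, max_new_tokens=8)  # inference only; backward is broken until bigscience-workshop#500
print(tokenizer.decode(outputs[0]))
```

On the server side, the new mode is selected by launching `petals.cli.run_server` with `--throughput dry_run` (see the argument-parser and server changes in the diff below).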

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
Authored by 2 people and committed by Dobromir Popov on Sep 5, 2023
1 parent 268aa02 commit 239a23b
Showing 9 changed files with 356 additions and 15 deletions.
7 changes: 6 additions & 1 deletion .github/workflows/run-tests.yaml
@@ -14,11 +14,13 @@ jobs:
          - { model: 'bigscience/bloom-560m', os: 'ubuntu', python-version: '3.11' }
          - { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.8' }
          - { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.11' }
+         - { model: 'petals-team/falcon-rw-1b', os: 'ubuntu', python-version: '3.8' }
+         - { model: 'petals-team/falcon-rw-1b', os: 'ubuntu', python-version: '3.11' }
          - { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.10' }
          - { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.11' }
      fail-fast: false
    runs-on: ${{ matrix.os }}-latest
-   timeout-minutes: 15
+   timeout-minutes: 20
    steps:
      - name: Increase swap space
        if: ${{ matrix.os == 'ubuntu' }}
@@ -93,6 +95,9 @@ jobs:
          # [Step 2] Run PyTest
          # Share disk cache between Petals servers, clients, and HF Transformers
          export TRANSFORMERS_CACHE=~/.cache/petals
+         # Necessary for @pytest.mark.forked to work properly on macOS, see https://github.com/kevlened/pytest-parallel/issues/93
+         export no_proxy=*
+         export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES
5 changes: 3 additions & 2 deletions src/petals/cli/run_server.py
@@ -106,12 +106,13 @@ def main():
"and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.")

parser.add_argument('--throughput',
type=lambda value: value if value in ['auto', 'eval'] else float(value),
type=lambda value: value if value in ['auto', 'eval', 'dry_run'] else float(value),
default='auto',
help='Expected server throughput (a float measured in RPS). '
'If set to "auto" (default), the script evaluates network and compute throughput '
'on the first run and uses these estimates for future runs. '
'If set to "eval", the script re-evaluates the throughput and overrides the cache.')
'If set to "eval", the script re-evaluates the throughput and overrides the cache. '
'If set to "dry_run", the script re-evaluates the throughput and exits.')
parser.add_argument('--update_period', type=float, required=False, default=120,
help='Server will report blocks to DHT once in this many seconds')
parser.add_argument('--expiration', type=float, required=False, default=None,
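
The `type=` lambda above accepts the three keywords unchanged and otherwise parses the value as a float in requests per second. A tiny sketch mirroring that behavior (added for this writeup, not part of the commit):

```python
# Hedged sketch of the argparse `type=` callable shown in the diff above.
def parse_throughput(value: str):
    return value if value in ["auto", "eval", "dry_run"] else float(value)

assert parse_throughput("dry_run") == "dry_run"  # keyword added in this PR
assert parse_throughput("2.5") == 2.5            # explicit throughput in RPS
```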
1 change: 1 addition & 0 deletions src/petals/models/__init__.py
@@ -1,2 +1,3 @@
from petals.models.bloom import *
+from petals.models.falcon import *
from petals.models.llama import *
15 changes: 15 additions & 0 deletions src/petals/models/falcon/__init__.py
@@ -0,0 +1,15 @@
from petals.models.falcon.block import WrappedFalconBlock
from petals.models.falcon.config import DistributedFalconConfig
from petals.models.falcon.model import (
    DistributedFalconForCausalLM,
    DistributedFalconForSequenceClassification,
    DistributedFalconModel,
)
from petals.utils.auto_config import register_model_classes

register_model_classes(
    config=DistributedFalconConfig,
    model=DistributedFalconModel,
    model_for_causal_lm=DistributedFalconForCausalLM,
    model_for_sequence_classification=DistributedFalconForSequenceClassification,
)
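
A short check of what this registration enables (an editor's sketch; it assumes `petals.utils.auto_config` also exposes an `AutoDistributedConfig` helper, as it does for the other model families):

```python
# Hedged sketch: after register_model_classes() above, the auto classes should resolve Falcon checkpoints.
from petals.utils.auto_config import AutoDistributedConfig

config = AutoDistributedConfig.from_pretrained("petals-team/falcon-rw-1b")
print(type(config).__name__)        # expected: DistributedFalconConfig
print(config.block_class.__name__)  # expected: WrappedFalconBlock
```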
94 changes: 94 additions & 0 deletions src/petals/models/falcon/block.py
@@ -0,0 +1,94 @@
"""
Falcon intermediate layer
Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py
See commit history for authorship.
"""
from typing import Optional, Tuple

import torch
from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor

KVCache = Tuple[torch.Tensor, torch.Tensor]


class WrappedFalconBlock(FalconDecoderLayer):
def forward(
self,
hidden_states: torch.Tensor,
*args,
attention_mask: Optional[torch.Tensor] = None,
alibi: Optional[torch.Tensor] = None,
layer_past: Optional[KVCache] = None,
use_cache: bool = False,
**kwargs
):
batch_size, seq_length = hidden_states.shape[:2]

if layer_past is not None:
layer_past = self._reorder_cache_from_bloom_to_falcon(layer_past)
past_length = 0 if layer_past is None else layer_past[0].shape[1]
seq_length_with_past = seq_length + past_length

attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
if alibi is None and self.config.alibi:
alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
attention_mask = FalconModel._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length)

outputs = super().forward(
hidden_states,
*args,
attention_mask=attention_mask,
alibi=alibi,
layer_past=layer_past,
use_cache=use_cache,
**kwargs
)

if use_cache:
present_key_value = outputs[-1]
present_key_value = self._reorder_cache_from_falcon_to_bloom(present_key_value)
outputs = outputs[:-1] + (present_key_value,)

return outputs

def _reorder_cache_from_bloom_to_falcon(self, key_value: KVCache) -> KVCache:
key_states, value_states = key_value

key_states = key_states.permute(0, 2, 1)
assert key_states.shape == value_states.shape # Both are [batch_size * num_kv_heads, seq_len, head_dim]

if self.config.new_decoder_architecture:
key_states = self._expand_states(key_states)
value_states = self._expand_states(value_states)

return (key_states, value_states)

def _reorder_cache_from_falcon_to_bloom(self, key_value: KVCache) -> KVCache:
key_states, value_states = key_value

if self.config.new_decoder_architecture:
key_states = self._collapse_states(key_states)
value_states = self._collapse_states(value_states)

assert key_states.shape == value_states.shape # Both are [batch_size * num_kv_heads, seq_len, head_dim]
key_states = key_states.permute(0, 2, 1)

return (key_states, value_states)

def _expand_states(self, state: torch.Tensor) -> torch.Tensor:
batch_size_x_num_kv_heads, seq_len, head_dim = state.shape
batch_size = batch_size_x_num_kv_heads // self.config.num_kv_heads

state = state.view(batch_size, self.config.num_kv_heads, 1, seq_len, head_dim)
state = state.expand(-1, -1, self.config.num_key_value_groups, -1, -1) # No copy
state = state.reshape(batch_size * self.config.num_attention_heads, seq_len, head_dim) # Involves a copy
return state

def _collapse_states(self, state: torch.Tensor) -> torch.Tensor:
batch_size_x_num_attn_heads, seq_len, head_dim = state.shape
batch_size = batch_size_x_num_attn_heads // self.config.num_attention_heads

state = state.view(batch_size, self.config.num_kv_heads, self.config.num_key_value_groups, seq_len, head_dim)
state = state[:, :, 0]
state = state.view(batch_size * self.config.num_kv_heads, seq_len, head_dim)
return state
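
The `_expand_states` / `_collapse_states` helpers replicate each key/value head across its query-head group and later undo that replication. A standalone round-trip sketch with made-up sizes (illustrative only, not part of the commit):

```python
import torch

# Toy sizes standing in for config.num_kv_heads, config.num_key_value_groups, etc.
batch, num_kv_heads, groups, seq_len, head_dim = 2, 4, 3, 5, 8
num_attention_heads = num_kv_heads * groups

state = torch.randn(batch * num_kv_heads, seq_len, head_dim)

# Expand: repeat each KV head for every query-head group (view + expand avoid a copy until reshape)
expanded = (
    state.view(batch, num_kv_heads, 1, seq_len, head_dim)
    .expand(-1, -1, groups, -1, -1)
    .reshape(batch * num_attention_heads, seq_len, head_dim)
)

# Collapse: keep one representative per group, since all copies within a group are identical
collapsed = (
    expanded.view(batch, num_kv_heads, groups, seq_len, head_dim)[:, :, 0]
    .reshape(batch * num_kv_heads, seq_len, head_dim)
)

assert torch.equal(collapsed, state)  # the round trip is lossless
```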
45 changes: 45 additions & 0 deletions src/petals/models/falcon/config.py
@@ -0,0 +1,45 @@
import os
from typing import Optional, Union

from hivemind import get_logger
from transformers.models.falcon import FalconConfig
from transformers.models.falcon.modeling_falcon import FalconAttention

from petals.client.config import ClientConfig
from petals.client.lm_head import LMHeadConfig
from petals.client.ptune import PTuneConfig
from petals.models.falcon.block import WrappedFalconBlock
from petals.utils.auto_config import DefaultRevisionMixin

logger = get_logger(__name__)


class DistributedFalconConfig(DefaultRevisionMixin, FalconConfig, ClientConfig, PTuneConfig, LMHeadConfig):
    block_class = WrappedFalconBlock
    attn_class = FalconAttention
    block_prefix = "transformer.h"

    @property
    def num_key_value_groups(self) -> int:
        if self.new_decoder_architecture:
            return self.num_attention_heads // self.num_kv_heads
        if self.multi_query:
            return self.num_attention_heads
        return 1

    @classmethod
    def from_pretrained(
        cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
    ):
        loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
        if loading_from_repo and dht_prefix is None:
            dht_prefix = str(model_name_or_path)
            dht_prefix = dht_prefix.split("/")[-1]  # Use only repo name to merge blocks hosted by different accounts
            dht_prefix = dht_prefix.replace(".", "-")
            logger.info(f"Using DHT prefix: {dht_prefix}")

        result = super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)
        config = result[0] if isinstance(result, tuple) else result
        if config.pad_token_id is None:
            config.pad_token_id = 0
        return result
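
Two small illustrations of the logic above (editor's sketches, independent of the Petals codebase): the DHT-prefix normalization from `from_pretrained` and the `num_key_value_groups` property for the three Falcon attention layouts.

```python
# Hedged sketch of the DHT-prefix normalization used in from_pretrained() above.
def derive_dht_prefix(model_name_or_path: str) -> str:
    prefix = model_name_or_path.split("/")[-1]  # keep only the repo name
    return prefix.replace(".", "-")             # replace dots in the prefix

assert derive_dht_prefix("petals-team/falcon-rw-1b") == "falcon-rw-1b"

# Hedged sketch of num_key_value_groups for the three Falcon attention layouts.
def num_key_value_groups(new_decoder_architecture: bool, multi_query: bool,
                         num_attention_heads: int, num_kv_heads: int) -> int:
    if new_decoder_architecture:               # grouped-query attention
        return num_attention_heads // num_kv_heads
    if multi_query:                            # multi-query attention
        return num_attention_heads
    return 1                                   # classic multi-head attention
```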
149 changes: 149 additions & 0 deletions src/petals/models/falcon/model.py
@@ -0,0 +1,149 @@
from typing import Optional

import hivemind
import torch
import torch.nn as nn
from hivemind.utils.logging import get_logger
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from transformers.models.falcon import (
    FalconForCausalLM,
    FalconForSequenceClassification,
    FalconModel,
    FalconPreTrainedModel,
)

from petals.client.from_pretrained import FromPretrainedMixin
from petals.client.lm_head import LMHead
from petals.client.ptune import PTuneMixin
from petals.client.remote_generation import RemoteGenerationMixin, RemotePastKeyValues
from petals.client.remote_sequential import RemoteSequential
from petals.models.falcon.config import DistributedFalconConfig
from petals.utils.auto_config import DefaultRevisionMixin

logger = get_logger(__name__)


class DistributedFalconModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMixin, FalconModel):
    """FalconModel, but all transformer layers are hosted by the swarm"""

    _keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing
    _keys_to_ignore_on_load_unexpected = [r"^transformer\.h\."]

    config_class = DistributedFalconConfig

    def __init__(self, config: DistributedFalconConfig, *, dht: Optional[hivemind.DHT] = None):
        n_layer, config.num_hidden_layers = config.num_hidden_layers, 0  # Prevent initialization
        super().__init__(config)
        assert len(self.h) == 0
        config.num_hidden_layers = n_layer

        self.h = RemoteSequential(config, dht=dht)

        self.requires_grad_(False)  # Forbid accumulate grads for embeddings and layernorm
        self.init_prompts(config)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[RemotePastKeyValues] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # The causal mask will be added on the server-side
        assert (
            attention_mask is None or (attention_mask == 1).all()
        ), f"Custom attention masks are not supported, {attention_mask=}"
        assert head_mask is None, f"Custom head masks are not supported, {head_mask=}"
        assert use_cache is None or use_cache, f"{use_cache=} is not supported"
        assert not output_attentions, f"{output_attentions=} is not supported"
        assert not output_hidden_states, f"{output_hidden_states=} is not supported"
        assert return_dict is None or return_dict, f"{return_dict=} is not supported"

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        if self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0:
            batch_size = inputs_embeds.shape[0]
            prompts, intermediate_prompts = self.get_prompt(batch_size)
            inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
        else:
            prompts = intermediate_prompts = None

        hidden_states = self.word_embeddings_layernorm(inputs_embeds)
        output_shape = input_shape + (hidden_states.size(-1),)

        hidden_states = self.h(
            hidden_states,
            prompts=intermediate_prompts,
            hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None,
        )

        # Remove prefix
        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
            hidden_states = hidden_states[:, self.pre_seq_len :]

        # Add last hidden state
        hidden_states = self.ln_f(hidden_states)
        hidden_states = hidden_states.view(output_shape)
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=RemotePastKeyValues(),
            hidden_states=None,
            attentions=None,
        )

    @property
    def word_embeddings_layernorm(self) -> nn.Module:  # For compatibility with RemoteGenerationMixin
        return nn.Identity()


class DistributedFalconForCausalLM(DefaultRevisionMixin, FromPretrainedMixin, RemoteGenerationMixin, FalconForCausalLM):
    _keys_to_ignore_on_load_missing = DistributedFalconModel._keys_to_ignore_on_load_missing
    _keys_to_ignore_on_load_unexpected = DistributedFalconModel._keys_to_ignore_on_load_unexpected

    config_class = DistributedFalconConfig

    def __init__(self, config: DistributedFalconConfig):
        FalconPreTrainedModel.__init__(self, config)
        self.transformer = DistributedFalconModel(config)
        self.lm_head = LMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head


class DistributedFalconForSequenceClassification(
    DefaultRevisionMixin, FromPretrainedMixin, FalconForSequenceClassification
):
    _keys_to_ignore_on_load_missing = DistributedFalconModel._keys_to_ignore_on_load_missing
    _keys_to_ignore_on_load_unexpected = DistributedFalconModel._keys_to_ignore_on_load_unexpected

    config_class = DistributedFalconConfig

    def __init__(self, config: DistributedFalconConfig):
        FalconPreTrainedModel.__init__(self, config)
        self.num_labels = config.num_labels

        self.transformer = DistributedFalconModel(config)
        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()
16 changes: 9 additions & 7 deletions src/petals/server/server.py
@@ -5,6 +5,7 @@
import multiprocessing as mp
import os
import random
+import sys
import threading
import time
from typing import Dict, List, Optional, Sequence, Union
@@ -186,10 +187,7 @@ def __init__(
            check_device_balance(self.tensor_parallel_devices)

        if quant_type is None:
-           if device.type == "cuda":
-               quant_type = QuantType.NF4 if self.block_config.model_type == "llama" else QuantType.INT8
-           else:
-               quant_type = QuantType.NONE
+           quant_type = QuantType.NF4 if device.type == "cuda" else QuantType.NONE
        self.quant_type = quant_type
        logger.info(f"Model weights are loaded in {get_dtype_name(torch_dtype, quant_type)} format")

@@ -234,8 +232,9 @@ def __init__(
        self.attn_cache_bytes = self._cache_bytes_per_block * num_blocks
        logger.info(f"Attention cache for all blocks will consume up to {self.attn_cache_bytes / gib:.2f} GiB")

-       assert isinstance(throughput, float) or throughput in ["auto", "eval"]
-       if throughput in ["auto", "eval"]:
+       assert isinstance(throughput, float) or throughput in ["auto", "eval", "dry_run"]
+       if throughput in ["auto", "eval", "dry_run"]:
+           force_eval = throughput in ["eval", "dry_run"]
            throughput_info = get_server_throughput(
                converted_model_name_or_path,
                self.block_config,
@@ -245,9 +244,12 @@
                quant_type=quant_type,
                tensor_parallel_devices=self.tensor_parallel_devices,
                reachable_via_relay=reachable_via_relay,
-               force_eval=(throughput == "eval"),
+               force_eval=force_eval,
                cache_dir=cache_dir,
            )
+           if throughput == "dry_run":
+               logger.info("Finished estimating throughput, exiting")
+               sys.exit(0)
        else:
            throughput_info = {"throughput": throughput}
        self.server_info = ServerInfo(
(The diff for the one remaining changed file was not loaded in this view.)
