Skip to content

Commit

Permalink
Update transformers to 4.31.0 and peft to 0.4.0 (#371)
Browse files Browse the repository at this point in the history
  • Loading branch information
borzunov committed Jul 19, 2023
1 parent 1ab35c2 commit c735dd7
Show file tree
Hide file tree
Showing 10 changed files with 50 additions and 52 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/run-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [ '3.7', '3.8', '3.9', '3.10' ]
python-version: [ '3.8', '3.9', '3.10' ]
fail-fast: false
timeout-minutes: 15
steps:
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ print(tokenizer.decode(outputs[0])) # A cat sat on a mat...

### Connect your GPU and increase Petals capacity

Run these commands in an [Anaconda](https://www.anaconda.com) env (requires Linux and Python 3.7+):
Run these commands in an [Anaconda](https://www.anaconda.com) env (requires Linux and Python 3.8+):

```bash
conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia
Expand Down
8 changes: 4 additions & 4 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ classifiers =
Intended Audience :: Science/Research
License :: OSI Approved :: MIT License
Programming Language :: Python :: 3
Programming Language :: Python :: 3.7
Programming Language :: Python :: 3.8
Programming Language :: Python :: 3.9
Programming Language :: Python :: 3.10
Topic :: Scientific/Engineering
Topic :: Scientific/Engineering :: Mathematics
Topic :: Scientific/Engineering :: Artificial Intelligence
Expand All @@ -29,14 +29,14 @@ classifiers =
package_dir =
= src
packages = find:
python_requires = >=3.7
python_requires = >=3.8
install_requires =
torch>=1.12
bitsandbytes==0.40.1.post1
accelerate>=0.16.0,<0.21.0
huggingface-hub>=0.11.1,<1.0.0
tokenizers>=0.13.3
transformers>=4.30.1,<4.31.0
transformers>=4.31.0,<5.0.0
speedtest-cli==2.1.3
pydantic>=1.10,<2.0 # 2.0 is incompatible with hivemind==1.1.8
hivemind==1.1.8
Expand All @@ -46,7 +46,7 @@ install_requires =
cpufeature>=0.2.0
packaging>=20.9
sentencepiece>=0.1.99
peft@git+https://github.com/huggingface/peft@5884bdbea49e5e71e2cd06ecfa484bb635063735
peft>=0.4.0
safetensors>=0.3.1
Dijkstar>=2.6.0

Expand Down
4 changes: 2 additions & 2 deletions src/petals/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
assert (
version.parse("4.30.1") <= version.parse(transformers.__version__) < version.parse("5.0.0")
), "Please install a proper transformers version: pip install transformers>=4.30.1,<5.0.0"
version.parse("4.31.0") <= version.parse(transformers.__version__) < version.parse("5.0.0")
), "Please install a proper transformers version: pip install transformers>=4.31.0,<5.0.0"


def _override_bfloat16_mode_default():
Expand Down
2 changes: 1 addition & 1 deletion src/petals/cli/run_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def main():
parser.add_argument("--mean_balance_check_period", type=float, default=60,
help="Check the swarm's balance every N seconds (and rebalance it if necessary)")

parser.add_argument("--use_auth_token", action='store_true', help="auth token for from_pretrained")
parser.add_argument("--token", action='store_true', help="Hugging Face hub auth token for .from_pretrained()")
parser.add_argument('--quant_type', type=str, default=None, choices=[choice.name.lower() for choice in QuantType],
help="Quantize blocks to 8-bit (int8 from the LLM.int8() paper) or "
"4-bit (nf4 from the QLoRA paper) formats to save GPU memory. "
Expand Down
16 changes: 4 additions & 12 deletions src/petals/models/bloom/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@
class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
"""BloomModel, but all transformer layers are hosted by the swarm"""

_keys_to_ignore_on_load_missing = (
BloomModel._keys_to_ignore_on_load_missing + PTuneMixin._keys_to_ignore_on_load_missing
)
_keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing
_keys_to_ignore_on_load_unexpected = [r"^h\."]

config_class = DistributedBloomConfig
Expand Down Expand Up @@ -93,11 +91,8 @@ def forward(


class DistributedBloomForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, BloomForCausalLM):
_keys_to_ignore_on_load_missing = (
BloomForCausalLM._keys_to_ignore_on_load_missing
+ DistributedBloomModel._keys_to_ignore_on_load_missing
+ [r"^lm_head\."] # Missing since they are shared with input embeddings
)
_keys_to_ignore_on_load_missing = DistributedBloomModel._keys_to_ignore_on_load_missing
_keys_to_ignore_on_load_missing += [r"^lm_head\."] # Missing since they are shared with input embeddings
_keys_to_ignore_on_load_unexpected = DistributedBloomModel._keys_to_ignore_on_load_unexpected

config_class = DistributedBloomConfig
Expand All @@ -115,10 +110,7 @@ def get_output_embeddings(self):


class DistributedBloomForSequenceClassification(FromPretrainedMixin, BloomForSequenceClassification):
_keys_to_ignore_on_load_missing = (
BloomForSequenceClassification._keys_to_ignore_on_load_missing
+ DistributedBloomModel._keys_to_ignore_on_load_missing
)
_keys_to_ignore_on_load_missing = DistributedBloomModel._keys_to_ignore_on_load_missing
_keys_to_ignore_on_load_unexpected = DistributedBloomModel._keys_to_ignore_on_load_unexpected

config_class = DistributedBloomConfig
Expand Down
9 changes: 4 additions & 5 deletions src/petals/models/llama/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
"""LlamaModel, but all transformer layers are hosted by the swarm"""

_keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing
_keys_to_ignore_on_load_unexpected = LlamaModel._keys_to_ignore_on_load_unexpected + [r"^model\.layers\."]
_keys_to_ignore_on_load_unexpected = [r"^model\.layers\."]

config_class = DistributedLlamaConfig

Expand Down Expand Up @@ -115,6 +115,8 @@ class DistributedLlamaForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, Ll
def __init__(self, config: DistributedLlamaConfig):
LlamaPreTrainedModel.__init__(self, config)
self.model = DistributedLlamaModel(config)
self.pretraining_tp = config.pretraining_tp
self.vocab_size = config.vocab_size
self.lm_head = LMHead(config)

# Initialize weights and apply final processing
Expand All @@ -129,10 +131,7 @@ def transformer(self) -> DistributedLlamaModel: # For compatibility with Remote


class DistributedLlamaForSequenceClassification(FromPretrainedMixin, LlamaForSequenceClassification):
_keys_to_ignore_on_load_missing = (
LlamaForSequenceClassification._keys_to_ignore_on_load_missing
+ DistributedLlamaModel._keys_to_ignore_on_load_missing
)
_keys_to_ignore_on_load_missing = DistributedLlamaModel._keys_to_ignore_on_load_missing
_keys_to_ignore_on_load_unexpected = DistributedLlamaModel._keys_to_ignore_on_load_unexpected

config_class = DistributedLlamaConfig
Expand Down
20 changes: 10 additions & 10 deletions src/petals/server/from_pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ def load_pretrained_block(
config: Optional[PretrainedConfig] = None,
torch_dtype: Union[torch.dtype, str] = "auto",
revision: Optional[str] = None,
use_auth_token: Optional[str] = None,
token: Optional[str] = None,
cache_dir: Optional[str] = None,
max_disk_space: Optional[int] = None,
) -> nn.Module:
if config is None:
config = AutoDistributedConfig.from_pretrained(model_name, use_auth_token=use_auth_token)
config = AutoDistributedConfig.from_pretrained(model_name, token=token)
if cache_dir is None:
cache_dir = DEFAULT_CACHE_DIR

Expand All @@ -54,7 +54,7 @@ def load_pretrained_block(
model_name,
block_prefix,
revision=revision,
use_auth_token=use_auth_token,
token=token,
cache_dir=cache_dir,
max_disk_space=max_disk_space,
)
Expand Down Expand Up @@ -82,12 +82,12 @@ def _load_state_dict_from_repo(
block_prefix: str,
*,
revision: Optional[str] = None,
use_auth_token: Optional[str] = None,
token: Optional[str] = None,
cache_dir: str,
max_disk_space: Optional[int] = None,
) -> StateDict:
index_file = get_file_from_repo(
model_name, filename="pytorch_model.bin.index.json", use_auth_token=use_auth_token, cache_dir=cache_dir
model_name, filename="pytorch_model.bin.index.json", use_auth_token=token, cache_dir=cache_dir
)
if index_file is not None: # Sharded model
with open(index_file) as f:
Expand All @@ -107,7 +107,7 @@ def _load_state_dict_from_repo(
model_name,
filename,
revision=revision,
use_auth_token=use_auth_token,
token=token,
cache_dir=cache_dir,
max_disk_space=max_disk_space,
)
Expand All @@ -125,7 +125,7 @@ def _load_state_dict_from_file(
filename: str,
*,
revision: Optional[str] = None,
use_auth_token: Optional[str] = None,
token: Optional[str] = None,
cache_dir: str,
max_disk_space: Optional[int] = None,
delay: float = 30,
Expand All @@ -137,7 +137,7 @@ def _load_state_dict_from_file(
model_name,
filename,
revision=revision,
use_auth_token=use_auth_token,
use_auth_token=token,
cache_dir=cache_dir,
local_files_only=True,
)
Expand All @@ -151,7 +151,7 @@ def _load_state_dict_from_file(
try:
with allow_cache_writes(cache_dir):
url = hf_hub_url(model_name, filename, revision=revision)
file_size = get_hf_file_metadata(url, token=use_auth_token).size
file_size = get_hf_file_metadata(url, token=token).size
if file_size is not None:
free_disk_space_for(file_size, cache_dir=cache_dir, max_disk_space=max_disk_space)
else:
Expand All @@ -161,7 +161,7 @@ def _load_state_dict_from_file(
model_name,
filename,
revision=revision,
use_auth_token=use_auth_token,
use_auth_token=token,
cache_dir=cache_dir,
local_files_only=False,
)
Expand Down
16 changes: 8 additions & 8 deletions src/petals/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def __init__(
balance_quality: float = 0.75,
mean_balance_check_period: float = 120,
mean_block_selection_delay: float = 2.5,
use_auth_token: Optional[str] = None,
token: Optional[str] = None,
quant_type: Optional[QuantType] = None,
tensor_parallel_devices: Optional[Sequence[torch.device]] = None,
skip_reachability_check: bool = False,
Expand All @@ -98,14 +98,14 @@ def __init__(
self.compression = compression
self.stats_report_interval, self.update_period = stats_report_interval, update_period
self.prefetch_batches, self.sender_threads = prefetch_batches, sender_threads
self.revision, self.use_auth_token = revision, use_auth_token
self.revision, self.token = revision, token

if custom_module_path is not None:
add_custom_models_from_file(custom_module_path)

self.block_config = AutoDistributedConfig.from_pretrained(
converted_model_name_or_path,
use_auth_token=use_auth_token,
token=token,
revision=revision,
)

Expand Down Expand Up @@ -271,7 +271,7 @@ def _choose_num_blocks(self) -> int:
self.block_config,
self.torch_dtype,
self.adapters,
use_auth_token=self.use_auth_token,
token=self.token,
cache_dir=self.cache_dir,
max_disk_space=self.max_disk_space,
)
Expand Down Expand Up @@ -316,7 +316,7 @@ def run(self):
prefetch_batches=self.prefetch_batches,
sender_threads=self.sender_threads,
revision=self.revision,
use_auth_token=self.use_auth_token,
token=self.token,
quant_type=self.quant_type,
tensor_parallel_devices=self.tensor_parallel_devices,
should_validate_reachability=self.should_validate_reachability,
Expand Down Expand Up @@ -409,7 +409,7 @@ def create(
update_period: float,
expiration: Optional[float],
revision: Optional[str],
use_auth_token: Optional[str],
token: Optional[str],
quant_type: QuantType,
tensor_parallel_devices: Sequence[torch.device],
should_validate_reachability: bool,
Expand Down Expand Up @@ -443,7 +443,7 @@ def create(
config=block_config,
torch_dtype=torch_dtype,
revision=revision,
use_auth_token=use_auth_token,
token=token,
cache_dir=cache_dir,
max_disk_space=max_disk_space,
)
Expand All @@ -456,7 +456,7 @@ def create(
quant_type,
adapters=server_info.adapters,
freeze=True,
use_auth_token=use_auth_token,
token=token,
cache_dir=cache_dir,
max_disk_space=max_disk_space,
)
Expand Down
23 changes: 15 additions & 8 deletions src/petals/utils/peft.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,20 @@ def load_specific_module(block_idx: int, filepath: str, framework: str = "pt", d
return tensors


def get_adapter_from_repo(repo_id: str, block_idx: Optional[int] = None, device: Optional[int] = None, **kwargs):
config_path = get_file_from_repo(repo_id, CONFIG_NAME, **kwargs)
def get_adapter_from_repo(
repo_id: str,
block_idx: Optional[int] = None,
device: Optional[int] = None,
*,
token: Optional[str] = None,
**kwargs,
):
config_path = get_file_from_repo(repo_id, CONFIG_NAME, use_auth_token=token, **kwargs)
if config_path is None:
raise RuntimeError(f"File {CONFIG_NAME} does not exist in repo {repo_id}")
config = PeftConfig.from_json_file(config_path)

weight_path = get_file_from_repo(repo_id, SAFETENSORS_WEIGHTS_NAME, **kwargs)
weight_path = get_file_from_repo(repo_id, SAFETENSORS_WEIGHTS_NAME, use_auth_token=token, **kwargs)
if weight_path is None:
raise RuntimeError(f"File {SAFETENSORS_WEIGHTS_NAME} does not exist in repo {repo_id}")
if block_idx is None:
Expand All @@ -65,7 +72,7 @@ def load_peft(
device: Optional[int] = None,
*,
revision: Optional[str] = None,
use_auth_token: Optional[str] = None,
token: Optional[str] = None,
cache_dir: str,
max_disk_space: Optional[int] = None,
delay: float = 30,
Expand All @@ -82,7 +89,7 @@ def load_peft(
block_idx,
device,
revision=revision,
use_auth_token=use_auth_token,
token=token,
cache_dir=cache_dir,
local_files_only=False,
)
Expand All @@ -93,9 +100,9 @@ def load_peft(
try:
with allow_cache_writes(cache_dir):
config_url = hf_hub_url(repo_id, CONFIG_NAME, revision=revision)
config_file_size = get_hf_file_metadata(config_url, token=use_auth_token).size
config_file_size = get_hf_file_metadata(config_url, token=token).size
weight_url = hf_hub_url(repo_id, SAFETENSORS_WEIGHTS_NAME, revision=revision)
weight_file_size = get_hf_file_metadata(weight_url, token=use_auth_token).size
weight_file_size = get_hf_file_metadata(weight_url, token=token).size

file_size = config_file_size + weight_file_size
if file_size is not None:
Expand All @@ -108,7 +115,7 @@ def load_peft(
block_idx,
device,
revision=revision,
use_auth_token=use_auth_token,
token=token,
cache_dir=cache_dir,
local_files_only=False,
)
Expand Down

0 comments on commit c735dd7

Please sign in to comment.