Generalize meta tensor pipeline #199

Merged on Jun 26, 2023 (10 commits)
1 change: 1 addition & 0 deletions examples/aml/text-generation-bloom.py
@@ -7,6 +7,7 @@
mii_configs = {
"dtype": "fp16",
"tensor_parallel": 8,
"meta_tensor": True,
}
name = "microsoft/bloom-deepspeed-inference-fp16"

1 change: 1 addition & 0 deletions examples/local/text-generation-bloom-example.py
@@ -8,6 +8,7 @@
"dtype": "fp16",
"tensor_parallel": 8,
"port_number": 50950,
"meta_tensor": True,
}
name = "microsoft/bloom-deepspeed-inference-fp16"

9 changes: 9 additions & 0 deletions mii/config.py
@@ -43,6 +43,7 @@ class MIIConfig(BaseModel):
tensor_parallel: int = 1
port_number: int = 50050
dtype: DtypeEnum = torch.float32
+    meta_tensor: bool = False
load_with_sys_mem: bool = False
enable_cuda_graph: bool = False
checkpoint_dict: Union[dict, None] = None
@@ -99,6 +100,14 @@ def auto_enable_load_balancing(cls, values):
values["enable_load_balancing"] = True
return values

+    @root_validator
+    def meta_tensor_or_sys_mem(cls, values):
+        if values.get("meta_tensor") and values.get("load_with_sys_mem"):
+            raise ValueError(
+                "`meta_tensor` and `load_with_sys_mem` cannot be active at the same time."
+            )
+        return values

class Config:
validate_all = True
validate_assignment = True
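To make the new validator concrete, a small usage sketch; it assumes pydantic v1 semantics (where ValidationError subclasses ValueError) and that the remaining MIIConfig fields keep their defaults:

from mii.config import MIIConfig

# Meta-tensor loading on its own is accepted.
ok = MIIConfig(tensor_parallel=8, dtype="fp16", meta_tensor=True)

# Combining it with system-memory staging is rejected by the new root_validator.
try:
    MIIConfig(meta_tensor=True, load_with_sys_mem=True)
except ValueError as err:  # pydantic wraps the message in a ValidationError
    print(err)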
7 changes: 2 additions & 5 deletions mii/constants.py
@@ -37,19 +37,16 @@ class Tasks(enum.Enum):
class ModelProvider(enum.Enum):
HUGGING_FACE = 1
ELEUTHER_AI = 2
-    HUGGING_FACE_LLM = 3
-    DIFFUSERS = 4
+    DIFFUSERS = 3


MODEL_PROVIDER_NAME_HF = "hugging-face"
MODEL_PROVIDER_NAME_EA = "eleuther-ai"
-MODEL_PROVIDER_NAME_HF_LLM = "hugging-face-llm"
MODEL_PROVIDER_NAME_DIFFUSERS = "diffusers"

MODEL_PROVIDER_MAP = {
MODEL_PROVIDER_NAME_HF: ModelProvider.HUGGING_FACE,
MODEL_PROVIDER_NAME_EA: ModelProvider.ELEUTHER_AI,
-    MODEL_PROVIDER_NAME_HF_LLM: ModelProvider.HUGGING_FACE_LLM,
MODEL_PROVIDER_NAME_DIFFUSERS: ModelProvider.DIFFUSERS
}

@@ -61,8 +58,8 @@ class ModelProvider(enum.Enum):
'gpt_neo': ModelProvider.HUGGING_FACE,
'gptj': ModelProvider.HUGGING_FACE,
'opt': ModelProvider.HUGGING_FACE,
+    'bloom': ModelProvider.HUGGING_FACE,
'gpt-neox': ModelProvider.ELEUTHER_AI,
-    'bloom': ModelProvider.HUGGING_FACE_LLM,
'stable-diffusion': ModelProvider.DIFFUSERS
}
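As a quick sanity check of the slimmed-down provider set, a minimal sketch against the constants shown above:

from mii.constants import MODEL_PROVIDER_MAP, ModelProvider

# Only three providers remain; Bloom no longer needs a dedicated one, since the
# meta-tensor path is now selected via MIIConfig.meta_tensor instead.
assert [p.name for p in ModelProvider] == ["HUGGING_FACE", "ELEUTHER_AI", "DIFFUSERS"]
assert MODEL_PROVIDER_MAP["hugging-face"] is ModelProvider.HUGGING_FACE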

18 changes: 8 additions & 10 deletions mii/models/load_models.py
@@ -45,6 +45,13 @@ def load_models(task_name,
assert mii_config.dtype == torch.half or mii_config.dtype == torch.int8, "Bloom models only support fp16/int8"
assert mii_config.enable_cuda_graph == False, "Bloom models do no support Cuda Graphs"
inference_pipeline = hf_provider(model_path, model_name, task_name, mii_config)
+        if mii_config.meta_tensor:
+            inf_config["checkpoint"] = inference_pipeline.checkpoint_dict
+            if mii_config.dtype == torch.int8:
+                # Support for older DeepSpeed versions
+                if "enable_qkv_quantization" in inspect.signature(
+                        deepspeed.init_inference).parameters:
+                    inf_config["enable_qkv_quantization"] = True
elif provider == mii.constants.ModelProvider.ELEUTHER_AI:
from mii.models.providers.eleutherai import eleutherai_provider
assert mii_config.dtype == torch.half, "gpt-neox only support fp16"
@@ -57,16 +64,6 @@
mii_config)
inf_config["training_mp_size"] = 2
inf_config["config"] = inference_pipeline.neox_args
-    elif provider == mii.constants.ModelProvider.HUGGING_FACE_LLM:
-        from mii.models.providers.llm import load_hf_llm
-        assert mii_config.dtype == torch.half or mii_config.dtype == torch.int8, "Bloom models only support fp16/int8"
-        assert mii_config.enable_cuda_graph == False, "Bloom models do no support Cuda Graphs"
-        inference_pipeline = load_hf_llm(model_path, model_name, task_name, mii_config)
-        inf_config["checkpoint"] = inference_pipeline.checkpoint_dict
-        if mii_config.dtype == torch.int8:
-            if "enable_qkv_quantization" in inspect.signature(
-                    deepspeed.init_inference).parameters:
-                inf_config["enable_qkv_quantization"] = True
elif provider == mii.constants.ModelProvider.DIFFUSERS:
from mii.models.providers.diffusers import diffusers_provider
inference_pipeline = diffusers_provider(model_path,
@@ -92,6 +89,7 @@ def load_models(task_name,
inference_pipeline.model = engine

elif ds_zero:
+        assert not mii_config.meta_tensor, "ZeRO-Inference does not support meta tensors"
ds_config = DeepSpeedConfig(ds_config_path)
#TODO: don't read ds-config from disk, we should pass this around as a dict instead
ds_config_dict = json.load(open(ds_config_path, 'r'))
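For orientation, the checkpoint dict attached to inf_config above is what lets DeepSpeed materialize the meta-tensor weights later. The sketch below is a simplified equivalent of that hand-off, using explicit keyword arguments instead of the assembled inf_config that load_models actually passes:

import torch
import deepspeed

# inference_pipeline comes from hf_provider(...) above; with meta_tensor=True its
# model holds only meta tensors and checkpoint_dict describes the real weights.
engine = deepspeed.init_inference(
    inference_pipeline.model,
    mp_size=mii_config.tensor_parallel,
    dtype=torch.half,
    checkpoint=inference_pipeline.checkpoint_dict,
    replace_with_kernel_inject=True,
)
inference_pipeline.model = engine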
196 changes: 185 additions & 11 deletions mii/models/providers/huggingface.py
@@ -3,22 +3,196 @@

# DeepSpeed Team
import os
import json
import torch
-from transformers import pipeline
import deepspeed
from deepspeed.inference.engine import InferenceEngine
from deepspeed import OnDevice
from pathlib import Path
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, AutoConfig
from transformers.utils.hub import EntryNotFoundError
from transformers.modeling_utils import get_checkpoint_shard_files
from transformers.utils import WEIGHTS_NAME, WEIGHTS_INDEX_NAME

from mii.utils import mii_cache_path

-def hf_provider(model_path, model_name, task_name, mii_config):
-    if mii_config.load_with_sys_mem:
try:
from transformers.utils import cached_path, hf_bucket_url
USE_NEW_HF_CACHE = False
except ImportError:
from huggingface_hub import snapshot_download
USE_NEW_HF_CACHE = True


class MetaTensorPipeline(object):
"""
Class for loading HuggingFace models using meta tensors
"""
def __init__(self, model, tokenizer, checkpoint_dict):
self.model = model
self.tokenizer = tokenizer
self.checkpoint_dict = checkpoint_dict

def __call__(self, inputs, **kwargs):
device = get_device()
torch.cuda.set_device(device)
if isinstance(self.model, InferenceEngine):
self.model = self.model.module

# expand proto list into py-list
inputs = [i for i in inputs]
tokens = self.tokenizer.batch_encode_plus(inputs,
return_tensors="pt",
padding=True)
for t in tokens:
if torch.is_tensor(tokens[t]):
tokens[t] = tokens[t].to(device)

greedy_output = self.model.generate(**tokens, **kwargs)
outputs = self.tokenizer.batch_decode(greedy_output, skip_special_tokens=True)

# construct output to align w. HF pipeline
output_dicts = []
for output in outputs:
output_dicts.append([{'generated_text': output}])

return output_dicts


def get_device(load_with_sys_mem=False):
if load_with_sys_mem:
device = torch.device("cpu")
else:
local_rank = int(os.getenv('LOCAL_RANK', '0'))
device = torch.device(f"cuda:{local_rank}")
-    inference_pipeline = pipeline(
-        task_name,
-        model=model_name,
-        device=device,
-        framework="pt",
-        use_auth_token=mii_config.hf_auth_token,
-        torch_dtype=mii_config.dtype,
-    )
return device


def _attempt_load(load_fn, model_name, cache_path, kwargs={}):
try:
value = load_fn(model_name, **kwargs)
except OSError:
print(f'Attempted load but failed, retrying using cache_dir={cache_path}')
value = load_fn(model_name, cache_dir=cache_path, **kwargs)
return value


def get_checkpoint_files(pretrained_model_name_or_path):
cache_dir = None
is_sharded = False
revision = None
local_files_only = False

filename = WEIGHTS_NAME
archive_file = hf_bucket_url(pretrained_model_name_or_path,
filename=filename,
revision=revision)

try:
resolved_archive_file = cached_path(
archive_file,
cache_dir=cache_dir,
local_files_only=local_files_only,
)
return [resolved_archive_file]

except (EntryNotFoundError, FileNotFoundError):
if filename == WEIGHTS_NAME:
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
archive_file = hf_bucket_url(
pretrained_model_name_or_path,
filename=WEIGHTS_INDEX_NAME,
revision=revision,
)
resolved_archive_file = cached_path(
archive_file,
cache_dir=cache_dir,
local_files_only=local_files_only,
)
is_sharded = True

if is_sharded:
# resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
pretrained_model_name_or_path,
resolved_archive_file,
cache_dir=cache_dir,
revision=revision
)

return resolved_archive_file


def create_checkpoint_dict(model_name, model_path, mii_config):
if USE_NEW_HF_CACHE:
model_path = snapshot_download(model_name,
cache_dir=model_path,
allow_patterns=[
"*.bin",
"*.json",
"*.pt",
],
revision=None)
if mii_config.checkpoint_dict:
mii_config.checkpoint_dict['base_dir'] = model_path
return mii_config.checkpoint_dict
elif os.path.isfile(os.path.join(model_path, "ds_inference_config.json")):
with open(os.path.join(model_path, "ds_inference_config.json")) as f:
data = json.load(f)
data["base_dir"] = model_path
return data
else:
if USE_NEW_HF_CACHE:
checkpoint_files = [
str(entry).split('/')[-1]
for entry in Path(model_path).rglob("*.[bp][it][n]") if entry.is_file()
]
else:
checkpoint_files = get_checkpoint_files(model_name)
data = {
"type": "DS_MODEL",
"checkpoints": checkpoint_files,
"version": 1.0,
"base_dir": model_path
}
return data


def load_with_meta_tensor(model_path, model_name, task_name, mii_config):
deepspeed.init_distributed('nccl')

cache_path = mii_cache_path()

tokenizer = _attempt_load(AutoTokenizer.from_pretrained,
model_name,
cache_path,
kwargs={"padding_side": 'left'})
tokenizer.pad_token = tokenizer.eos_token

config = _attempt_load(AutoConfig.from_pretrained, model_name, cache_path)

with OnDevice(dtype=torch.float16, device='meta', enabled=True):
model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16)
model = model.eval()
checkpoint_dict = create_checkpoint_dict(model_name, model_path, mii_config)
torch.distributed.barrier()
inference_pipeline = MetaTensorPipeline(model=model,
tokenizer=tokenizer,
checkpoint_dict=checkpoint_dict)
return inference_pipeline


def hf_provider(model_path, model_name, task_name, mii_config):
if mii_config.meta_tensor:
return load_with_meta_tensor(model_path, model_name, task_name, mii_config)
else:
device = get_device(load_with_sys_mem=mii_config.load_with_sys_mem)
inference_pipeline = pipeline(
task_name,
model=model_name,
device=device,
framework="pt",
use_auth_token=mii_config.hf_auth_token,
torch_dtype=mii_config.dtype,
)
return inference_pipeline
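Putting the pieces together, a rough sketch of the meta-tensor provider flow. The model path is illustrative, and in practice load_models drives these steps and wraps pipeline.model in a DeepSpeed inference engine before the pipeline is called:

from mii.config import MIIConfig
from mii.models.providers.huggingface import hf_provider

config = MIIConfig(tensor_parallel=8, dtype="fp16", meta_tensor=True)
pipeline = hf_provider(model_path="/tmp/mii_models",   # illustrative checkpoint/cache path
                       model_name="microsoft/bloom-deepspeed-inference-fp16",
                       task_name="text-generation",
                       mii_config=config)

# pipeline.model holds only meta tensors at this point, and pipeline.checkpoint_dict
# tells DeepSpeed where the real weights live. Once deepspeed.init_inference has
# materialized them, the call mirrors the Hugging Face pipeline's output shape:
outputs = pipeline(["DeepSpeed is"], max_new_tokens=32)
print(outputs[0][0]["generated_text"])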