fix deepspeed and use cache issue (#1201)
Signed-off-by: Wang, Yi <yi.a.wang@intel.com>
sywangyi committed Jul 18, 2023
1 parent fbddd5b commit 4675d42
Showing 1 changed file with 46 additions and 40 deletions.
86 changes: 46 additions & 40 deletions workflows/chatbot/inference/generate.py
@@ -4,7 +4,6 @@
import torch.nn.functional as F
import re, os, logging
from threading import Thread
-from peft import PeftModel
from transformers import (
GenerationConfig,
AutoModelForCausalLM,
@@ -36,7 +35,6 @@
),
}


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("-bm", "--base_model_path", type=str, default="")
@@ -249,6 +247,8 @@ def load_model(
use_hpu_graphs=False,
cpu_jit=False,
use_cache=False,
+peft_path=None,
+use_deepspeed=False,
):
"""
Load the model and initialize the tokenizer.
@@ -328,10 +328,15 @@
if model.generation_config.eos_token_id is None:
model.generation_config.eos_token_id = tokenizer.eos_token_id

+if peft_path:
+from peft import PeftModel
+
+model = PeftModel.from_pretrained(model, peft_path)
+
if device == "hpu":
model = model.eval().to("hpu")

-if use_hpu_graphs:
+if use_hpu_graphs and not use_deepspeed:
from habana_frameworks.torch.hpu import wrap_in_hpu_graph

model = wrap_in_hpu_graph(model)
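The change above defers the `peft` import until an adapter path is actually supplied, and skips `wrap_in_hpu_graph` when DeepSpeed drives inference, since the DeepSpeed engine is configured with `enable_cuda_graph` further down. A condensed sketch of that ordering, with `prepare_model` as a hypothetical helper name and indentation restored; it is not the file's exact code:

def prepare_model(model, device, peft_path=None, use_hpu_graphs=False, use_deepspeed=False):
    # Attach a PEFT adapter only when one is requested; keeping the import
    # local means the base install does not need the peft package.
    if peft_path:
        from peft import PeftModel
        model = PeftModel.from_pretrained(model, peft_path)

    if device == "hpu":
        model = model.eval().to("hpu")
        # When DeepSpeed drives inference, graph capture is left to the engine,
        # so the model is only wrapped on the non-DeepSpeed HPU path.
        if use_hpu_graphs and not use_deepspeed:
            from habana_frameworks.torch.hpu import wrap_in_hpu_graph
            model = wrap_in_hpu_graph(model)
    return model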
@@ -393,6 +398,7 @@ def predict_stream(**params):
`bad_words_ids` (list or None): Contains a list of token IDs that should not appear in the generated text.
`force_words_ids` (list or None): Contains a list of token IDs that must be included in the generated text.
`use_hpu_graphs` (bool): Determines whether to utilize Habana Processing Units (HPUs) for accelerated generation.
+`use_cache` (bool): Determines whether to utilize kv cache for accelerated generation.
Returns:
generator: A generator that yields the generated streaming text.
@@ -418,9 +424,8 @@
bad_words_ids = params["bad_words_ids"] if "bad_words_ids" in params else None
force_words_ids = params["force_words_ids"] if "force_words_ids" in params else None
use_hpu_graphs = params["use_hpu_graphs"] if "use_hpu_graphs" in params else False
-use_cache = params["use_kv_cache"] if "use_kv_cache" in params else False
+use_cache = params["use_cache"] if "use_cache" in params else False
prompt = params["prompt"]

model = MODELS[model_name]["model"]
tokenizer = MODELS[model_name]["tokenizer"]
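`predict_stream` (and `predict`, further down) now reads the flag from the `use_cache` key instead of `use_kv_cache`, matching the docstring entry added above, and forwards it to the `transformers` generation config, where it toggles the key/value cache. A minimal sketch of that hand-off; `make_generation_config` is a hypothetical helper, not the file's exact wiring, and `model`, `input_tokens`, and `params` are assumed to exist:

from transformers import GenerationConfig

def make_generation_config(params: dict) -> GenerationConfig:
    # Collect generation options from a plain dict of request parameters.
    return GenerationConfig(
        max_new_tokens=params.get("max_new_tokens", 32),
        do_sample=params.get("do_sample", True),
        use_cache=params.get("use_cache", False),  # same default as the diff: cache off unless requested
    )

# usage:
# output = model.generate(**input_tokens, generation_config=make_generation_config(params))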

@@ -568,6 +573,7 @@ def predict(**params):
`bad_words_ids` (list or None): Contains a list of token IDs that should not appear in the generated text.
`force_words_ids` (list or None): Contains a list of token IDs that must be included in the generated text.
`use_hpu_graphs` (bool): Determines whether to utilize Habana Processing Units (HPUs) for accelerated generation.
+`use_cache` (bool): Determines whether to utilize kv cache for accelerated generation.
Returns:
generator: A generator that yields the generated streaming text.
@@ -593,8 +599,7 @@ def predict(**params):
bad_words_ids = params["bad_words_ids"] if "bad_words_ids" in params else None
force_words_ids = params["force_words_ids"] if "force_words_ids" in params else None
use_hpu_graphs = params["use_hpu_graphs"] if "use_hpu_graphs" in params else False
-use_cache = params["use_kv_cache"] if "use_kv_cache" in params else False

+use_cache = params["use_cache"] if "use_cache" in params else False
prompt = params["prompt"]
model = MODELS[model_name]["model"]
tokenizer = MODELS[model_name]["tokenizer"]
@@ -680,7 +685,6 @@ def predict(**params):
generation_config.force_words_ids = force_words_ids
generation_config.num_return_sequences = num_return_sequences
generation_config.static_shapes = is_graph_optimized

with torch.no_grad():
generation_output = model.generate(
**input_tokens,
@@ -778,7 +782,8 @@ def main():
args.tokenizer_name if args.tokenizer_name is not None else base_model_path
)

-if use_deepspeed:
+if use_deepspeed and args.habana:
+config = AutoConfig.from_pretrained(base_model_path, trust_remote_code=True)
with deepspeed.OnDevice(dtype=torch.bfloat16, device="cpu"):
load_model(
base_model_path,
@@ -787,22 +792,18 @@
use_hpu_graphs=args.use_hpu_graphs,
cpu_jit=args.jit,
use_cache=args.use_kv_cache,
+use_deepspeed=True,
)
-model = MODELS[base_model_path]["model"]
-if peft_model_path:
-model = PeftModel.from_pretrained(model, peft_model_path)
-model = model.eval()
-# Initialize the model
-ds_inference_kwargs = {"dtype": torch.bfloat16}
-ds_inference_kwargs["tensor_parallel"] = {"tp_size": 8}
-ds_inference_kwargs["enable_cuda_graph"] = args.use_hpu_graphs
-# Make sure all devices/nodes have access to the model checkpoints
-torch.distributed.barrier()
-config = AutoConfig.from_pretrained(base_model_path, trust_remote_code=True)
-ds_inference_kwargs["injection_policy"] = get_ds_injection_policy(config)
-model = deepspeed.init_inference(model, **ds_inference_kwargs)
-model = model.module
-MODELS[base_model_path]["model"] = model
+model = MODELS[base_model_path]["model"]
+# Initialize the model
+ds_inference_kwargs = {"dtype": torch.bfloat16}
+ds_inference_kwargs["tensor_parallel"] = {"tp_size": world_size}
+ds_inference_kwargs["enable_cuda_graph"] = args.use_hpu_graphs
+# Make sure all devices/nodes have access to the model checkpoints
+torch.distributed.barrier()
+ds_inference_kwargs["injection_policy"] = get_ds_injection_policy(config)
+model = deepspeed.init_inference(model, **ds_inference_kwargs)
+MODELS[base_model_path]["model"] = model.module
else:
load_model(
base_model_path,
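In the reshuffled DeepSpeed branch the model is loaded once via `load_model(..., use_deepspeed=True)` and only then wrapped by `deepspeed.init_inference`, with the tensor-parallel size taken from `world_size` instead of the previous hard-coded 8 and the injection policy derived from the config loaded at the top of the branch. A condensed sketch of that sequence, assuming `get_ds_injection_policy` comes from the surrounding file and the distributed process group is already initialized; `wrap_with_deepspeed` is a hypothetical helper name:

import torch
import deepspeed

def wrap_with_deepspeed(model, config, world_size, use_hpu_graphs):
    ds_inference_kwargs = {
        "dtype": torch.bfloat16,                               # bf16 inference
        "tensor_parallel": {"tp_size": world_size},            # shard across all launched ranks
        "enable_cuda_graph": use_hpu_graphs,                   # graph capture handled by the engine
        "injection_policy": get_ds_injection_policy(config),   # helper defined elsewhere in generate.py
    }
    # Make sure every rank can see the checkpoint files before sharding starts.
    torch.distributed.barrier()
    engine = deepspeed.init_inference(model, **ds_inference_kwargs)
    return engine.module  # unwrap so callers keep working with a plain nn.Module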
@@ -818,8 +819,9 @@
logger.info(f"device: {args.device}, n_hpu: {world_size}, bf16")

# warmup, the first time inference take longer because of graph compilation
-start_time = time.time()
-print("Warmup, Response: ")
+if args.local_rank in [-1, 0]:
+start_time = time.time()
+print("Warmup, Response: ")
for new_text in predict_stream(
model_name=base_model_path,
device="hpu" if args.habana else "cpu",
@@ -837,15 +839,17 @@
):
if args.local_rank in [-1, 0]:
print(new_text, end="", flush=True)
logger.info(f"duration: {time.time() - start_time}")
if args.local_rank in [-1, 0]:
logger.info(f"duration: {time.time() - start_time}")

for idx, tp in enumerate(zip(prompts, args.instructions)):
prompt, instruction = tp
idxs = f"{idx+1}"
logger.info("=" * 30 + idxs + "=" * 30)
logger.info(f"Instruction: {instruction}")
start_time = time.time()
logger.info("Response: ")
if args.local_rank in [-1, 0]:
logger.info("=" * 30 + idxs + "=" * 30)
logger.info(f"Instruction: {instruction}")
start_time = time.time()
logger.info("Response: ")
for new_text in predict_stream(
model_name=base_model_path,
device="hpu" if args.habana else "cpu",
@@ -863,16 +867,18 @@
):
if args.local_rank in [-1, 0]:
print(new_text, end="", flush=True)
logger.info(f"duration: {time.time() - start_time}")
logger.info("=" * (60 + len(idxs)))
if args.local_rank in [-1, 0]:
logger.info(f"duration: {time.time() - start_time}")
logger.info("=" * (60 + len(idxs)))

for idx, tp in enumerate(zip(prompts, args.instructions)):
prompt, instruction = tp
idxs = f"{idx+1}"
logger.info("=" * 30 + idxs + "=" * 30)
logger.info(f"Instruction: {instruction}")
start_time = time.time()
logger.info("Response: ")
if args.local_rank in [-1, 0]:
logger.info("=" * 30 + idxs + "=" * 30)
logger.info(f"Instruction: {instruction}")
start_time = time.time()
logger.info("Response: ")
out = predict(
model_name=base_model_path,
device="hpu" if args.habana else "cpu",
@@ -889,9 +895,9 @@
num_return_sequences=args.num_return_sequences,
)
if args.local_rank in [-1, 0]:
print(f"nonstream out = {out}")
logger.info(f"duration: {time.time() - start_time}")
logger.info("=" * (60 + len(idxs)))
print(f"whole sentence out = {out}")
logger.info(f"duration: {time.time() - start_time}")
logger.info("=" * (60 + len(idxs)))


if __name__ == "__main__":