From 411b9d1363f9589f5524597c6ae49bd8ed26cee0 Mon Sep 17 00:00:00 2001
From: Jeff Rasley
Date: Mon, 18 Jul 2022 14:16:09 -0700
Subject: [PATCH] add OnDevice and remove zero-inference

---
 scripts/inference/bloom-ds-inference.py | 63 ++-----------------------
 1 file changed, 4 insertions(+), 59 deletions(-)

diff --git a/scripts/inference/bloom-ds-inference.py b/scripts/inference/bloom-ds-inference.py
index c553750db..034e62450 100644
--- a/scripts/inference/bloom-ds-inference.py
+++ b/scripts/inference/bloom-ds-inference.py
@@ -41,12 +41,12 @@ parser.add_argument("--local_rank", required=False, type=int, help="used by dist launchers")
 parser.add_argument("--batch_size", default=1, type=int, help="batch size")
 parser.add_argument("--benchmark", action="store_true", help="additionally run benchmark")
-parser.add_argument("--cpu_offload", action="store_true", help="whether to activate CPU offload")
 
 args = parser.parse_args()
 
 local_rank = int(os.getenv('LOCAL_RANK', '0'))
 world_size = int(os.getenv('WORLD_SIZE', '1'))
 
+deepspeed.init_distributed('nccl')
 
 ### Model loading and instantiating on GPU (via ZeRO)
 
@@ -132,44 +132,14 @@ def get_checkpoint_files(pretrained_model_name_or_path):
 else:
     dtype = torch.bfloat16
 
-#dtype = config.dtype
-#print(dtype)
-
-model_hidden_size = config.hidden_size
-train_batch_size = 1 * world_size
-
-ds_config = {
-    "fp16": {
-        "enabled": dtype == torch.float16,
-    },
-    "bf16": {
-        "enabled": dtype == torch.bfloat16,
-    },
-    "zero_optimization": {
-        "stage": 3,
-        "overlap_comm": True,
-        "contiguous_gradients": True,
-        "reduce_bucket_size": model_hidden_size * model_hidden_size,
-        "stage3_prefetch_bucket_size": 0.9 * model_hidden_size * model_hidden_size,
-        "stage3_param_persistence_threshold": 0
-    },
-    "steps_per_print": 2000,
-    "train_batch_size": train_batch_size,
-    "train_micro_batch_size_per_gpu": 1,
-    "wall_clock_breakdown": False
-}
-
-if args.cpu_offload:
-    ds_config["zero_optimization"]["offload_param"] = dict(device="cpu", pin_memory=True)
-
-dschf = HfDeepSpeedConfig(ds_config) # this tells from_pretrained to instantiate directly on gpus
-
 if args.benchmark:
     torch.cuda.empty_cache()
     gc.collect()
     deepspeed.runtime.utils.see_memory_usage('pre-from-pretrained', force=True)
 
-model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16)
+# Construct model with fake meta tensors, later will be replaced during ds-inference ckpt load
+with deepspeed.OnDevice(dtype=dtype, device='meta'):
+    model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16)
 
 if args.benchmark:
     deepspeed.runtime.utils.see_memory_usage('post-from-pretrained', force=True)
 
@@ -178,36 +148,11 @@ def get_checkpoint_files(pretrained_model_name_or_path):
 
 rank = dist.get_rank()
 
-if rank == 0:
-    print(ds_config)
-
-ds_engine = deepspeed.initialize(model=model, config_params=ds_config)[0]
-ds_engine.module.eval()
-model = ds_engine.module
-
-### Deepspeed-ZeRO Unloading
-
-# a must to remove ZeRO-installed hooks!
-ds_engine.destroy()
-
-# free GPU storage used by ZeRO
-from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
-def ds_clear_params(ds_engine):
-    for p in ds_engine.parameters():
-        if hasattr(p, "ds_tensor"):
-            p.ds_tensor = torch.empty(0, dtype=p.dtype, device=p.device)
-            p.ds_status = ZeroParamStatus.NOT_AVAILABLE
-
-ds_clear_params(ds_engine)
-del ds_engine
-
 if args.benchmark:
     torch.cuda.empty_cache()
     gc.collect()
     deepspeed.runtime.utils.see_memory_usage('post-init-ds-zero-init', force=True)
 
-
-
 ### Deepspeed-Inference Loading
 
 checkpoints_json = "checkpoints.json"
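
Note (illustration, not part of the commit): the patch swaps the old ZeRO-3 load/unload
dance for a model built entirely on the meta device, whose parameters carry shapes and
dtypes but no storage; the weights only appear when DeepSpeed-Inference loads the
checkpoint shards listed in checkpoints.json. Below is a minimal, self-contained sketch
of that hand-off under assumptions: `model_name` and the shard glob path are hypothetical,
and the `deepspeed.init_inference(..., checkpoint=..., replace_with_kernel_inject=True)`
keywords follow the DeepSpeed API of this era rather than this script's exact code.

    import glob
    import json
    import os

    import deepspeed
    import torch
    from transformers import AutoConfig, AutoModelForCausalLM

    model_name = "bigscience/bloom"                 # assumed; any sharded HF checkpoint
    world_size = int(os.getenv('WORLD_SIZE', '1'))  # set by the deepspeed launcher

    config = AutoConfig.from_pretrained(model_name)

    # Build the model on the meta device: no rank ever materializes the full weights.
    with deepspeed.OnDevice(dtype=torch.bfloat16, device='meta'):
        model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16)
    model = model.eval()

    # Point DeepSpeed at the pre-sharded weight files so it can stream them
    # straight into the kernel-injected modules (hypothetical local layout).
    checkpoints_json = "checkpoints.json"
    checkpoint_files = sorted(glob.glob("/path/to/bloom/*.bin"))
    with open(checkpoints_json, "w") as f:
        json.dump({"type": "BLOOM-176B",
                   "checkpoints": checkpoint_files,
                   "version": 1.0}, f)

    # init_inference replaces the meta tensors while loading the checkpoint and
    # tensor-parallelizes the model across world_size GPUs.
    model = deepspeed.init_inference(model,
                                     mp_size=world_size,
                                     dtype=torch.bfloat16,
                                     checkpoint=checkpoints_json,
                                     replace_with_kernel_inject=True)
    model = model.module

A script structured this way is launched with the DeepSpeed runner, e.g.
`deepspeed --num_gpus 8 scripts/inference/bloom-ds-inference.py`, so that every rank
constructs the cheap meta model and only its tensor-parallel slice of the real weights.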