Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add OnDevice and remove zero-inference #316

Merged
merged 1 commit into from
Jul 19, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
63 changes: 4 additions & 59 deletions scripts/inference/bloom-ds-inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@
parser.add_argument("--local_rank", required=False, type=int, help="used by dist launchers")
parser.add_argument("--batch_size", default=1, type=int, help="batch size")
parser.add_argument("--benchmark", action="store_true", help="additionally run benchmark")
parser.add_argument("--cpu_offload", action="store_true", help="whether to activate CPU offload")
args = parser.parse_args()

local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '1'))

deepspeed.init_distributed('nccl')
stas00 marked this conversation as resolved.
Show resolved Hide resolved

### Model loading and instantiating on GPU (via ZeRO)

Expand Down Expand Up @@ -132,44 +132,14 @@ def get_checkpoint_files(pretrained_model_name_or_path):
else:
dtype = torch.bfloat16

#dtype = config.dtype
#print(dtype)

model_hidden_size = config.hidden_size
train_batch_size = 1 * world_size

ds_config = {
"fp16": {
"enabled": dtype == torch.float16,
},
"bf16": {
"enabled": dtype == torch.bfloat16,
},
"zero_optimization": {
"stage": 3,
"overlap_comm": True,
"contiguous_gradients": True,
"reduce_bucket_size": model_hidden_size * model_hidden_size,
"stage3_prefetch_bucket_size": 0.9 * model_hidden_size * model_hidden_size,
"stage3_param_persistence_threshold": 0
},
"steps_per_print": 2000,
"train_batch_size": train_batch_size,
"train_micro_batch_size_per_gpu": 1,
"wall_clock_breakdown": False
}

if args.cpu_offload:
ds_config["zero_optimization"]["offload_param"] = dict(device="cpu", pin_memory=True)

dschf = HfDeepSpeedConfig(ds_config) # this tells from_pretrained to instantiate directly on gpus

if args.benchmark:
torch.cuda.empty_cache()
gc.collect()
deepspeed.runtime.utils.see_memory_usage('pre-from-pretrained', force=True)

model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16)
# Construct model with fake meta tensors, later will be replaced during ds-inference ckpt load
with deepspeed.OnDevice(dtype=dtype, device='meta'):
model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16)

if args.benchmark:
deepspeed.runtime.utils.see_memory_usage('post-from-pretrained', force=True)
Expand All @@ -178,36 +148,11 @@ def get_checkpoint_files(pretrained_model_name_or_path):

rank = dist.get_rank()

if rank == 0:
print(ds_config)

ds_engine = deepspeed.initialize(model=model, config_params=ds_config)[0]
ds_engine.module.eval()
model = ds_engine.module

### Deepspeed-ZeRO Unloading

# a must to remove ZeRO-installed hooks!
ds_engine.destroy()

# free GPU storage used by ZeRO
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
def ds_clear_params(ds_engine):
for p in ds_engine.parameters():
if hasattr(p, "ds_tensor"):
p.ds_tensor = torch.empty(0, dtype=p.dtype, device=p.device)
p.ds_status = ZeroParamStatus.NOT_AVAILABLE

ds_clear_params(ds_engine)
del ds_engine

if args.benchmark:
torch.cuda.empty_cache()
gc.collect()
deepspeed.runtime.utils.see_memory_usage('post-init-ds-zero-init', force=True)



### Deepspeed-Inference Loading

checkpoints_json = "checkpoints.json"
Expand Down