Release swap buffers for persisted params (#2089)
* Split parameter offload from z3

* Format fixes

* Bug fixes

* Cleanup

* Remove dead code

* Release swap buffers for persisted params

* Format fixes

* Format fixes

* Pass args correctly

* Use pinned memory for nvme offload

* Merge with master

* Fix missing import

* Model persistence params

* Fix merge issues

* Handle none device

* Use log_dist
tjruwase committed Jul 31, 2022
1 parent a039e22 commit 2210ebe
Showing 6 changed files with 83 additions and 52 deletions.
66 changes: 32 additions & 34 deletions deepspeed/runtime/engine.py
@@ -300,9 +300,8 @@ def __init__(
monitor_memory=False,
)

if dist.get_rank() == 0:
logger.info(
f"DeepSpeed Flops Profiler Enabled: {self.flops_profiler_enabled()}")
log_dist(f"DeepSpeed Flops Profiler Enabled: {self.flops_profiler_enabled()}",
ranks=[0])

if self.flops_profiler_enabled():
self.flops_profiler = FlopsProfiler(self.module, self)
@@ -688,6 +687,9 @@ def zero_prefetch_bucket_size(self):
def zero_param_persistence_threshold(self):
return self._config.zero_config.param_persistence_threshold

def zero_model_persistence_threshold(self):
return self._config.zero_config.model_persistence_threshold

def zero_gather_16bit_weights_on_model_save(self):
return self._config.zero_config.gather_16bit_weights_on_model_save

@@ -779,18 +781,17 @@ def _configure_lr_scheduler(self, client_lr_scheduler):
# First check for scheduler in json configuration
lr_scheduler = self._scheduler_from_config(self.optimizer)
if lr_scheduler:
if self.global_rank == 0:
logger.info(
f"DeepSpeed using configured LR scheduler = {self.scheduler_name()}")
log_dist(
f"DeepSpeed using configured LR scheduler = {self.scheduler_name()}",
ranks=[0])
self.lr_scheduler = lr_scheduler
else:
if isinstance(client_lr_scheduler, Callable):
if self.global_rank == 0:
logger.info('DeepSpeed using client callable to create LR scheduler')
log_dist('DeepSpeed using client callable to create LR scheduler',
ranks=[0])
self.lr_scheduler = client_lr_scheduler(self.basic_optimizer)
else:
if self.global_rank == 0:
logger.info('DeepSpeed using client LR scheduler')
log_dist('DeepSpeed using client LR scheduler', ranks=[0])
self.lr_scheduler = client_lr_scheduler

log_dist(f'DeepSpeed LR Scheduler = {self.lr_scheduler}', ranks=[0])
@@ -1093,31 +1094,26 @@ def _configure_optimizer(self, client_optimizer, model_parameters):
client_optimizer.param_groups[:] = [
pg for pg in client_optimizer.param_groups if len(pg["params"]) != 0
]
if self.global_rank == 0:
logger.info(
"Removing param_group that has no 'params' in the client Optimizer"
)
log_dist(
"Removing param_group that has no 'params' in the client Optimizer",
ranks=[0])

basic_optimizer = client_optimizer
if self.global_rank == 0:
logger.info('Using client Optimizer as basic optimizer')
log_dist('Using client Optimizer as basic optimizer', ranks=[0])
else:
basic_optimizer = client_optimizer(model_parameters)
if self.global_rank == 0:
logger.info('Using client callable to create basic optimizer')
log_dist('Using client callable to create basic optimizer', ranks=[0])
else:
basic_optimizer = self._configure_basic_optimizer(model_parameters)
if self.global_rank == 0:
logger.info(
"Using DeepSpeed Optimizer param name {} as basic optimizer".format(
self.optimizer_name()))
log_dist(
f"Using DeepSpeed Optimizer param name {self.optimizer_name()} as basic optimizer",
ranks=[0])

self._check_for_duplicates(basic_optimizer)

self.basic_optimizer = basic_optimizer
if self.global_rank == 0:
logger.info("DeepSpeed Basic Optimizer = {}".format(
basic_optimizer.__class__.__name__))
log_dist("DeepSpeed Basic Optimizer = {basic_optimizer.__class__.__name__}",
ranks=[0])

if self.zero_optimization():
assert (
@@ -1138,8 +1134,7 @@ def _configure_optimizer(self, client_optimizer, model_parameters):
elif self.amp_enabled():
assert not (self.fp16_enabled() or self.bfloat16_enabled()), "Cannot enable both amp with (legacy) fp16 or bfloat16 mode"
amp_params = self.amp_params()
if self.global_rank == 0:
logger.info(f"Initializing AMP with these params: {amp_params}")
log_dist(f"Initializing AMP with these params: {amp_params}", ranks=[0])
try:
logger.info("Initializing Apex amp from: {}".format(amp.__path__))
except NameError:
@@ -1340,8 +1335,8 @@ def _configure_bf16_optimizer(self, optimizer):
if optimizer is None:
optimizer = DummyOptim(list(self.module.parameters()))

if self.global_rank == 0:
logger.info('Creating unfused BF16 optimizer')
log_dist('Creating BF16 optimizer', ranks=[0])

timers = self.timers if self.wall_clock_breakdown() else None
optimizer = BF16_Optimizer(
optimizer,
@@ -1356,7 +1351,6 @@

def _configure_zero_optimizer(self, optimizer):
zero_stage = self.zero_optimization_stage()
log_dist('Creating fp16 ZeRO stage {} optimizer'.format(zero_stage), ranks=[0])
assert self.communication_data_type in (torch.float16, torch.bfloat16), "ZeRO supports only 'communication_data_type': ['fp16', 'bfp16']"
timers = self.timers if self.wall_clock_breakdown() else None

@@ -1374,6 +1368,8 @@ def _configure_zero_optimizer(self, optimizer):
round_robin_gradients = self.zero_round_robin_gradients()
assert not isinstance(optimizer, DummyOptim), "zero stage 2 requires an optimizer"

log_dist('Creating fp16 ZeRO stage {} optimizer'.format(zero_stage),
ranks=[0])
# Overlap and contiguous grads are meaningless in stage 1 and are ignored
if zero_stage == ZeroStageEnum.optimizer_states:
overlap_comm = False
@@ -1419,10 +1415,8 @@ def _configure_zero_optimizer(self, optimizer):

elif zero_stage == ZeroStageEnum.weights:
assert not self.has_moe_layers, "MoE not supported with Stage 3"
logger.info("Initializing ZeRO Stage 3") if dist.get_rank() == 0 else None
from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3

if isinstance(optimizer, DummyOptim):
log_dist("Creating ZeRO Offload", ranks=[0])
optimizer = DeepSpeedZeRoOffload(
self.module,
timers=timers,
@@ -1432,10 +1426,13 @@ def _configure_zero_optimizer(self, optimizer):
max_reuse_distance=self.zero_max_reuse_distance(),
max_live_parameters=self.zero_max_live_parameters(),
param_persistence_threshold=self.zero_param_persistence_threshold(),
model_persistence_threshold=self.zero_model_persistence_threshold(),
offload_param_config=self.zero_offload_param(),
mpu=self.mpu)
else:

log_dist('Creating fp16 ZeRO stage {} optimizer'.format(zero_stage),
ranks=[0])
from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3
optimizer = DeepSpeedZeroOptimizer_Stage3(
self.module,
optimizer,
@@ -1451,6 +1448,7 @@ def _configure_zero_optimizer(self, optimizer):
max_reuse_distance=self.zero_max_reuse_distance(),
max_live_parameters=self.zero_max_live_parameters(),
param_persistence_threshold=self.zero_param_persistence_threshold(),
model_persistence_threshold=self.zero_model_persistence_threshold(),
dp_process_group=self.data_parallel_group,
reduce_scatter=self.zero_reduce_scatter(),
overlap_comm=self.zero_overlap_comm(),
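
The recurring change in engine.py above replaces rank-guarded logger.info calls with log_dist(..., ranks=[0]). A minimal sketch of the two patterns, assuming deepspeed.utils.log_dist / deepspeed.utils.logger and torch.distributed as the process-group handle (the message text is illustrative only):

    import torch.distributed as dist
    from deepspeed.utils import logger, log_dist

    def report_optimizer(name: str) -> None:
        # Old pattern: every call site guards on the rank before logging.
        if dist.is_initialized() and dist.get_rank() == 0:
            logger.info(f"DeepSpeed Basic Optimizer = {name}")
        # New pattern: rank filtering happens inside log_dist via ranks=[0].
        log_dist(f"DeepSpeed Basic Optimizer = {name}", ranks=[0])
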
4 changes: 4 additions & 0 deletions deepspeed/runtime/zero/config.py
@@ -4,6 +4,7 @@
"""

from pydantic import Field, validator
import sys
from typing import Optional
from enum import Enum
from deepspeed.runtime.config_utils import get_scalar_param, DeepSpeedConfigModel
@@ -114,6 +115,9 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
param_persistence_threshold: int = Field(1e5,
ge=0,
alias="stage3_param_persistence_threshold")
model_persistence_threshold: int = Field(sys.maxsize,
ge=0,
alias="stage3_model_persistence_threshold")
max_live_parameters: int = Field(1e9, ge=0, alias="stage3_max_live_parameters")
max_reuse_distance: int = Field(1e9, ge=0, alias="stage3_max_reuse_distance")
gather_16bit_weights_on_model_save: bool = Field(
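
The new model_persistence_threshold field caps the total number of parameter elements kept persistently in GPU memory, complementing the per-parameter param_persistence_threshold. A hypothetical user-side configuration exercising both fields (the two stage3_* keys are the aliases declared above; the other keys and all values are illustrative only, and my_model is a placeholder):

    import deepspeed

    ds_config = {
        "train_micro_batch_size_per_gpu": 1,
        "zero_optimization": {
            "stage": 3,
            # Per-parameter rule: tensors smaller than this many elements are
            # kept resident on the GPU instead of being partitioned away.
            "stage3_param_persistence_threshold": 1e5,
            # Model-wide cap (new in this commit): stop marking parameters as
            # persistent once the running total of persistent elements would
            # exceed this value. Defaults to sys.maxsize, i.e. no cap.
            "stage3_model_persistence_threshold": 1e6,
        },
    }

    # engine, optimizer, _, _ = deepspeed.initialize(
    #     model=my_model, model_parameters=my_model.parameters(), config=ds_config)
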
22 changes: 16 additions & 6 deletions deepspeed/runtime/zero/parameter_offload.py
@@ -3,6 +3,7 @@
Licensed under the MIT license.
"""

import sys
import torch
from torch.cuda import Stream
from collections import OrderedDict
@@ -173,10 +174,11 @@ def __init__(self,
max_reuse_distance=1000000000,
max_live_parameters=1000000000,
param_persistence_threshold=100000,
model_persistence_threshold=sys.maxsize,
offload_param_config=None,
mpu=None):

see_memory_usage("TensorOffload initialize beginning", force=True)
see_memory_usage("DeepSpeedZeRoOffload initialize [begin]", force=True)

print_rank_0(f"initialized {__class__.__name__} with args: {locals()}",
force=False)
@@ -196,8 +198,11 @@ def __init__(self,

_inject_parameters(module, ZeROOrderedDict)

self.persistence_threshold = int(param_persistence_threshold)
self.persistent_parameters = self.mark_persistent_parameters()
self.param_numel_persistence_threshold = int(param_persistence_threshold)
self.model_persistence_threshold = int(model_persistence_threshold)
self.persistent_parameters = self.mark_persistent_parameters(
self.param_numel_persistence_threshold,
self.model_persistence_threshold)

self.param_coordinators = {}
self._prefetch_bucket_sz = int(prefetch_bucket_size)
@@ -213,6 +218,8 @@ def __init__(self,
f'Created module hooks: forward = {len(self.forward_hooks)}, backward = {len(self.backward_hooks)}',
force=False)

see_memory_usage("DeepSpeedZeRoOffload initialize [end]", force=True)

@instrument_w_nvtx
def partition_all_parameters(self):
"""Partitioning Parameters that were not partitioned usually if parameters
@@ -291,20 +298,23 @@ def _end_of_forward_hook(module, *args):
global FWD_MODULE_STACK
FWD_MODULE_STACK.append(self.module)

def mark_persistent_parameters(self):
def mark_persistent_parameters(self, param_threshold, model_threshold):
persistent_params = []
total_persistent_parameters = 0
params_count = 0
for _, param in self.module.named_parameters(recurse=True):
if param.ds_numel < self.persistence_threshold:
if param.ds_numel + total_persistent_parameters > model_threshold:
continue

if param.ds_numel < param_threshold:
params_count += 1
param.ds_persist = True
persistent_params.append(param)
total_persistent_parameters += param.ds_numel

print_rank_0(
f"Parameter Offload: Total persistent parameters: {total_persistent_parameters} in {params_count} params",
force=False)
force=True)

return persistent_params

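
The interaction of the two thresholds in mark_persistent_parameters can be illustrated with a small standalone sketch (plain integers stand in for param.ds_numel; this mirrors the selection rule above but is not the DeepSpeed implementation):

    def select_persistent(numels, param_threshold=100_000, model_threshold=250_000):
        # A parameter is marked persistent when it is small on its own
        # (numel < param_threshold) and marking it would not push the running
        # total of persistent elements past the model-wide cap (model_threshold).
        persistent, total = [], 0
        for numel in numels:
            if numel + total > model_threshold:
                continue
            if numel < param_threshold:
                persistent.append(numel)
                total += numel
        return persistent, total

    # The 150_000-element tensor fails the per-parameter test; the final
    # 90_000-element tensor is skipped because it would breach the model cap.
    print(select_persistent([50_000, 150_000, 80_000, 90_000, 90_000]))
    # -> ([50000, 80000, 90000], 220000)
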
10 changes: 7 additions & 3 deletions deepspeed/runtime/zero/partition_parameters.py
@@ -683,9 +683,13 @@ def get_model():

# Remote device is the device where parameter partitions are stored
# It can be same as local_device or it could be CPU or NVMe.
self.remote_device = self.local_device if remote_device is None else remote_device
self.pin_memory = pin_memory if (self.remote_device
== OffloadDeviceEnum.cpu) else False
self.remote_device = self.local_device if remote_device in [
None,
OffloadDeviceEnum.none
] else remote_device
self.pin_memory = pin_memory if (
self.remote_device in [OffloadDeviceEnum.cpu,
OffloadDeviceEnum.nvme]) else False

# Enable fp16 param swapping to NVMe
if self.remote_device == OffloadDeviceEnum.nvme:
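
The reworked logic above (a) treats both None and OffloadDeviceEnum.none as "no parameter offload", falling back to the local device, and (b) requests pinned host memory for NVMe offload as well as CPU offload. A sketch of the same decision with plain strings standing in for the enum members:

    def resolve_offload(local_device, remote_device, pin_memory):
        # A missing or "none" remote device means partitions stay on the local device.
        remote = local_device if remote_device in (None, "none") else remote_device
        # Pinned (page-locked) host buffers are requested for both CPU and NVMe
        # offload, since NVMe swapping also stages data through host memory.
        pinned = pin_memory if remote in ("cpu", "nvme") else False
        return remote, pinned

    print(resolve_offload("cuda:0", None, True))    # -> ('cuda:0', False)
    print(resolve_offload("cuda:0", "nvme", True))  # -> ('nvme', True)
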
10 changes: 10 additions & 0 deletions deepspeed/runtime/zero/partitioned_param_coordinator.py
@@ -400,6 +400,16 @@ def __all_gather_params(self, params: Set[Parameter]) -> None:
assert param.ds_status == ZeroParamStatus.INFLIGHT, param.ds_summary()
self.__inflight_param_registry[param] = handle

# Release swap buffers for persisted params on nvme since they will never be partitioned or evicted from GPU
swap_persisted_params = [
p for p in partitioned_params
if p.ds_persist and p.ds_tensor.final_location == OffloadDeviceEnum.nvme
]
if swap_persisted_params:
swap_persisted_params[
0].nvme_swapper.remove_partition_and_release_buffers(
swap_persisted_params)

@instrument_w_nvtx
def __release_param(self, param: Parameter) -> None:
if param.ds_status == ZeroParamStatus.AVAILABLE and not param.ds_active_sub_modules:
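
The new block in __all_gather_params frees NVMe swap buffers for parameters that are both persistent (ds_persist) and NVMe-backed: once gathered, such parameters stay resident on the GPU for the rest of the run, so their swap space is never needed again. A compact restatement as a helper, assuming OffloadDeviceEnum is importable from deepspeed.runtime.zero.offload_config and that all NVMe-backed parameters share one swapper instance (which the indexed call above implies):

    from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum

    def release_persisted_nvme_buffers(partitioned_params):
        # Persisted NVMe-backed params are gathered once and never re-partitioned
        # or evicted, so their swap buffers can be released immediately.
        swap_persisted_params = [
            p for p in partitioned_params
            if p.ds_persist and p.ds_tensor.final_location == OffloadDeviceEnum.nvme
        ]
        if swap_persisted_params:
            # A single call on the shared swapper releases buffers for the group.
            swapper = swap_persisted_params[0].nvme_swapper
            swapper.remove_partition_and_release_buffers(swap_persisted_params)
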
23 changes: 14 additions & 9 deletions deepspeed/runtime/zero/stage3.py
@@ -3,6 +3,7 @@
Licensed under the MIT license.
"""

import sys
import gc
import collections
from typing import Deque, Dict, Tuple
@@ -88,6 +89,7 @@ def __init__(self,
max_reuse_distance=1000000000,
max_live_parameters=1000000000,
param_persistence_threshold=100000,
model_persistence_threshold=sys.maxsize,
dp_process_group=None,
reduce_scatter=True,
overlap_comm=False,
@@ -146,15 +148,18 @@ def __init__(self,
self.params_in_nvme_and_cpu = False
self.max_params_in_cpu = 0

self.parameter_offload = DeepSpeedZeRoOffload(module,
timers,
ds_config,
overlap_comm,
prefetch_bucket_size,
max_reuse_distance,
max_live_parameters,
param_persistence_threshold,
offload_param_config)
self.parameter_offload = DeepSpeedZeRoOffload(
module=module,
timers=timers,
ds_config=ds_config,
overlap_comm=overlap_comm,
prefetch_bucket_size=prefetch_bucket_size,
max_reuse_distance=max_reuse_distance,
max_live_parameters=max_live_parameters,
param_persistence_threshold=param_persistence_threshold,
model_persistence_threshold=model_persistence_threshold,
offload_param_config=offload_param_config)

self.persistent_parameters = self.parameter_offload.persistent_parameters
self._configure_offloading(offload_optimizer_config, offload_param_config)

Expand Down
