Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 14 additions & 13 deletions deepspeed/runtime/zero/stage_1_and_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ def _enforce_cpu_offload():
# not sure why apex was cloning the weights before flattening
# removing cloning here

# Compute group size for VRAM check (need 2x model size on GPU to flatten in place: params + flat copy)
# Compute group size for memory check (need 2x model size on accelerator to flatten in place: params + flat copy)
orig_group_numel = sum(param.numel() for param in self.bit16_groups[i])
alignment = self.nccl_start_alignment_factor * dist.get_world_size(group=self.real_dp_process_group[i])
aligned_numel = int(math.ceil(orig_group_numel / alignment)) * alignment
Expand All @@ -378,13 +378,13 @@ def _enforce_cpu_offload():

empty_cache()
accelerator = get_accelerator()
available_vram = accelerator.available_memory() if accelerator.is_available() else 0
# Flatten on GPU only if we have enough VRAM for the flat buffer (2x = params already there + copy)
flatten_on_gpu = (accelerator.is_available() and (available_vram >= flat_buffer_bytes))
available_memory = accelerator.available_memory() if accelerator.is_available() else 0
# Flatten on accelerator device if we have enough memory for the flat buffer
flatten_on_accelerator = (accelerator.is_available() and (available_memory >= flat_buffer_bytes))

if not flatten_on_gpu:
if not flatten_on_accelerator:
see_memory_usage(f"Before moving param group {i} to CPU")
# move all the parameters to cpu to free up GPU space for creating flat buffer
# move all the parameters to cpu to free up accelerator memory for creating flat buffer
for param in self.bit16_groups[i]:
param.cpu_data = param.data.cpu()
param.data = torch.empty(1).to(param.device)
Expand All @@ -409,21 +409,21 @@ def _enforce_cpu_offload():
# Create meta tensors list, ordered according to round_robin_tensors
meta_tensors = []
for param in round_robin_tensors:
if flatten_on_gpu:
if flatten_on_accelerator:
meta_tensors.append(torch.zeros_like(param.data, device="meta"))
else:
meta_tensors.append(torch.zeros_like(param.cpu_data, device="meta"))
self.round_robin_bit16_meta.append(meta_tensors)

if flatten_on_gpu:
logger.info(f"Flattening param group {i} on GPU (sufficient VRAM)")
if flatten_on_accelerator:
logger.info(f"Flattening param group {i} on {accelerator.device_name()} (sufficient memory)")
flattened_buffer = self.flatten_dense_tensors_aligned(self.round_robin_bit16_groups[i],
alignment,
use_cpu_data=False)
use_cpu_data=False).detach()
self.bit16_groups_flat.append(flattened_buffer)
see_memory_usage(f"After flattening param group {i} on GPU", force=False)
see_memory_usage(f"After flattening param group {i} on {accelerator.device_name()}", force=False)
else:
logger.info(f"Flattening param group {i} on CPU (insufficient VRAM)")
logger.info(f"Flattening param group {i} on CPU (insufficient memory)")

flattened_buffer = self.flatten_dense_tensors_aligned(self.round_robin_bit16_groups[i],
alignment,
Expand All @@ -437,7 +437,8 @@ def _enforce_cpu_offload():
self.bit16_groups_flat.append(flattened_buffer.to(get_accelerator().current_device_name()))
del flattened_buffer

see_memory_usage(f"After flattening and moving param group {i} to GPU", force=False)
see_memory_usage(f"After flattening and moving param group {i} to {get_accelerator().device_name()}",
force=False)

if dist.get_rank(group=self.real_dp_process_group[i]) == 0:
see_memory_usage(f"After Flattening and after emptying param group {i} cache", force=False)
Expand Down
54 changes: 49 additions & 5 deletions tests/unit/v1/zero/test_stage2_flatten_on_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,14 @@
"""

import pytest
import torch
import deepspeed
from deepspeed.accelerator import get_accelerator
from deepspeed.utils import set_log_level_from_string
from unit.common import DistributedTest
from unit.simple_model import SimpleModel
from unit.simple_model import SimpleModel, random_dataloader

_DTYPE_MAP = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}


def _apply_dtype_to_config(config_dict, dtype):
Expand Down Expand Up @@ -70,10 +73,10 @@ def mock_logger_info(msg, *args, **kwargs):
model_parameters=model.parameters(),
)

# Small model + no CPU offload => GPU path; that path logs "on GPU"
gpu_path_logs = [m for m in log_messages if "Flattening param group" in m and "on GPU" in m]
assert gpu_path_logs, (
f"Expected GPU flatten path (logger.info should be called with 'Flattening param group' and 'on GPU'). "
# Small model + no CPU offload => accelerator path logs "Flattening param group ... (sufficient memory)"
accel_path_logs = [m for m in log_messages if "Flattening param group" in m and "(sufficient memory)" in m]
assert accel_path_logs, (
f"Expected accelerator flatten path (log should contain 'Flattening param group' and '(sufficient memory)'). "
f"Captured messages: {log_messages}")

def test_flat_buffers_on_accelerator(self, zero_stage, dtype):
Expand Down Expand Up @@ -107,3 +110,44 @@ def test_flat_buffers_on_accelerator(self, zero_stage, dtype):
device_type = get_accelerator().device_name()
for i, flat in enumerate(opt.bit16_groups_flat):
assert flat.device.type == device_type, (f"Flat buffer {i} must be on {device_type}, got {flat.device}")

@pytest.mark.world_size(1)
def test_flatten_on_accelerator_training_step(self, zero_stage, dtype):
    """Regression: flat buffer must be detached so inplace ops during step don't crash."""
    # Nothing to exercise without a real accelerator device.
    if not get_accelerator().is_available():
        pytest.skip("Accelerator not available")

    # Minimal ZeRO config; dtype-specific fields are filled in by the shared helper.
    ds_config = {
        "train_micro_batch_size_per_gpu": 2,
        "gradient_accumulation_steps": 1,
        "zero_optimization": {
            "stage": zero_stage
        },
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 1e-3
            }
        },
    }
    _apply_dtype_to_config(ds_config, dtype)

    hdim = 64
    net = SimpleModel(hidden_dim=hdim, nlayers=2)
    ds_engine, _, _, _ = deepspeed.initialize(
        config=ds_config,
        model=net,
        model_parameters=net.parameters(),
    )

    # The flattened bit16 buffers must not carry autograd history: an attached
    # grad_fn would make the optimizer's in-place updates raise at step time.
    for buf in ds_engine.optimizer.bit16_groups_flat:
        assert buf.grad_fn is None, ("Flat buffer must be detached from autograd graph"
                                     " to prevent inplace-modification errors during optimizer step")

    # Run a few full fwd/bwd/step iterations to confirm training actually proceeds.
    loader = random_dataloader(model=ds_engine,
                               total_samples=8,
                               hidden_dim=hdim,
                               device=ds_engine.device,
                               dtype=_DTYPE_MAP[dtype])
    for sample in loader:
        step_loss = ds_engine(sample[0], sample[1])
        ds_engine.backward(step_loss)
        ds_engine.step()
Loading