Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[kernel] support pure fp16 for cpu adam and update gemini optim tests #4921

Merged
merged 2 commits into from
Oct 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 94 additions & 107 deletions colossalai/kernel/cuda_native/csrc/cpu_adam.cpp

Large diffs are not rendered by default.

41 changes: 30 additions & 11 deletions colossalai/kernel/cuda_native/csrc/cpu_adam.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ SOFTWARE
#define SIMD_DIV(x, y) _mm512_div_ps(x, y)
#define SIMD_LOAD_HALF(x) \
_mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x)))
#define SIMD_STORE_HALF(x, d) \
_mm256_store_ps( \
x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT)))
#define SIMD_STORE_HALF(x, d) \
_mm256_storeu_ps((float *)(x), _mm256_castsi256_ps(_mm512_cvtps_ph( \
d, _MM_FROUND_TO_NEAREST_INT)))

#elif defined(__AVX256__) or defined(__AVX2__)
#define SIMD_WIDTH 8
Expand All @@ -66,9 +66,9 @@ SOFTWARE
#define SIMD_SQRT(x) _mm256_sqrt_ps(x)
#define SIMD_DIV(x, y) _mm256_div_ps(x, y)
#define SIMD_LOAD_HALF(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
#define SIMD_STORE_HALF(x, d) \
_mm_store_ps( \
x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT)))
#define SIMD_STORE_HALF(x, d) \
_mm_storeu_ps((float *)(x), _mm_castsi128_ps(_mm256_cvtps_ph( \
d, _MM_FROUND_TO_NEAREST_INT)))

#endif

Expand All @@ -83,11 +83,12 @@ union AVX_Data {

#endif

#define STEP(SPAN) \
void Step_##SPAN(float *_params, float *grads, float *_exp_avg, \
float *_exp_avg_sq, size_t _param_size, \
bool param_half_precision = false, \
bool grad_half_precision = false, float loss_scale = -1);
#define STEP(SPAN) \
void Step_##SPAN( \
float *_params, float *grads, float *_exp_avg, float *_exp_avg_sq, \
size_t _param_size, bool param_half_precision = false, \
bool grad_half_precision = false, bool momentum_half_precision = false, \
bool variance_half_precision = false, float loss_scale = -1);

class Adam_Optimizer {
public:
Expand Down Expand Up @@ -141,6 +142,24 @@ class Adam_Optimizer {
}
}

inline void simd_load(bool is_half, float *ptr, __half *h_ptr,
AVX_Data &data) {
if (is_half) {
data.data = SIMD_LOAD_HALF(h_ptr);
} else {
data.data = SIMD_LOAD(ptr);
}
}

inline void simd_store(bool is_half, float *ptr, __half *h_ptr,
AVX_Data &data) {
if (is_half) {
SIMD_STORE_HALF(h_ptr, data.data);
} else {
SIMD_STORE(ptr, data.data);
}
}

void step(size_t step, float lr, float beta1, float beta2, float epsilon,
float weight_decay, bool bias_correction, torch::Tensor &params,
torch::Tensor &grads, torch::Tensor &exp_avg,
Expand Down
3 changes: 1 addition & 2 deletions colossalai/nn/optimizer/cpu_adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,7 @@ def step(self, closure=None, div_scale: float = -1):
assert state["exp_avg"].device.type == "cpu", "exp_avg should stay on cpu"
assert state["exp_avg_sq"].device.type == "cpu", "exp_avg should stay on cpu"
self._pre_update(p, "exp_avg", "exp_avg_sq")
# FIXME(ver217): CPU adam kernel only supports fp32 states now
if p.grad.dtype is torch.bfloat16 or p.dtype is not torch.float:
if p.grad.dtype is torch.bfloat16:
# cpu adam kernel does not support bf16 now
bias_correction1 = 1 - beta1 ** state["step"]
bias_correction2 = 1 - beta2 ** state["step"]
Expand Down
3 changes: 1 addition & 2 deletions colossalai/nn/optimizer/hybrid_adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,7 @@ def step(self, closure=None, div_scale: float = -1):
assert state["exp_avg"].device.type == "cpu", "exp_avg should stay on cpu"
assert state["exp_avg_sq"].device.type == "cpu", "exp_avg should stay on cpu"
self._pre_update(p, "exp_avg", "exp_avg_sq")
# FIXME(ver217): CPU adam kernel only supports fp32 states now
if p.grad.dtype is torch.bfloat16 or p.dtype is not torch.float:
if p.grad.dtype is torch.bfloat16:
# cpu adam kernel does not support bf16 now
bias_correction1 = 1 - beta1 ** state["step"]
bias_correction2 = 1 - beta2 ** state["step"]
Expand Down
7 changes: 2 additions & 5 deletions tests/test_optimizer/test_adam_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,14 @@
_FUSED_ALLOWED_P_G_TYPES = [
(torch.float, torch.half),
(torch.float, torch.float),
(torch.half, torch.float),
(torch.half, torch.half),
(torch.bfloat16, torch.float),
(torch.float, torch.bfloat16),
(torch.bfloat16, torch.bfloat16),
]

_CPU_ALLOWED_P_G_TYPES = [
(torch.float, torch.half),
(torch.float, torch.float),
(torch.half, torch.float),
(torch.half, torch.half),
]

Expand Down Expand Up @@ -138,8 +135,8 @@ def check_adam_kernel(
master_exp_avg_sq = torch.zeros_like(master_p)
p = master_p.clone().to(p_dtype)
g = master_g.clone().to(g_dtype)
exp_avg = master_exp_avg.clone()
exp_avg_sq = master_exp_avg_sq.clone()
exp_avg = master_exp_avg.clone().to(p_dtype)
exp_avg_sq = master_exp_avg_sq.clone().to(p_dtype)

for step in range(1, 1 + n_steps):
torch_adam.update(step, master_p, master_g, master_exp_avg, master_exp_avg_sq)
Expand Down
2 changes: 0 additions & 2 deletions tests/test_optimizer/test_adam_optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
(torch.float, torch.float), # pure fp32
(torch.float, torch.half), # fp16 amp
(torch.float, torch.bfloat16), # bfloat16 amp
# (torch.half, torch.half), # FIXME(ver217): cpu adam kernel does not support pure fp16
# (torch.bfloat16, torch.bfloat16), # FIXME(ver217): cpu adam kernel does not support pure bfloat16
]

N_STEPS = 3
Expand Down
12 changes: 9 additions & 3 deletions tests/test_zero/test_gemini/test_grad_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module):

@parameterize("placement_config", PLACEMENT_CONFIGS)
@parameterize("model_name", ["gpt2"])
def exam_grad_clipping(placement_config, model_name: str):
@parameterize("master_weights", [True, False])
def exam_grad_clipping(placement_config, model_name: str, master_weights: bool):
set_seed(1912)
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
Expand Down Expand Up @@ -82,6 +83,7 @@ def exam_grad_clipping(placement_config, model_name: str):
chunk_config_dict=config_dict,
chunk_init_device=init_device,
pin_memory=True,
master_weights=master_weights,
**placement_config,
)

Expand All @@ -103,15 +105,19 @@ def exam_grad_clipping(placement_config, model_name: str):

torch_loss = run_fwd_bwd(torch_model, data, label, criterion, torch_optim)
loss = run_fwd_bwd(model, data, label, criterion, zero_optim)
assert_close(torch_loss, loss)

# as no master weights leads to error accumulation, we don't check the loss
if master_weights:
assert_close(torch_loss, loss)

import apex.amp as apex_amp

torch.nn.utils.clip_grad_norm_(apex_amp.master_params(torch_optim), 1.0)
torch_optim.step()
zero_optim.step()

check_param(model, torch_model)
if master_weights:
check_param(model, torch_model)


def run_dist(rank, world_size, port):
Expand Down
15 changes: 11 additions & 4 deletions tests/test_zero/test_gemini/test_optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,14 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dty
@parameterize("placement_config", PLACEMENT_CONFIGS)
@parameterize("model_name", TEST_MODELS)
@parameterize("mixed_precision", [torch.half, torch.bfloat16])
def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dtype):
@parameterize("master_weights", [True, False])
def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dtype, master_weights: bool):
set_seed(42)
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

torch_model = model_builder().cuda()
# apex no master weights leads to nan, so we don't use it
amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=128)
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
Expand All @@ -90,7 +92,9 @@ def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dt
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
config_dict[world_size]["chunk_size"] = 5000
config_dict[world_size]["keep_gathered"] = False
model = GeminiDDP(model, config_dict, **placement_config, mixed_precision=mixed_precision)
model = GeminiDDP(
model, config_dict, **placement_config, mixed_precision=mixed_precision, master_weights=master_weights
)

optimizer = HybridAdam(model.parameters(), lr=1e-3)
zero_optim = GeminiOptimizer(optimizer, model, initial_scale=128)
Expand All @@ -109,12 +113,15 @@ def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dt

torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim)
loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim)
assert_close(torch_loss, loss, rtol=rtol, atol=atol)
# as no master weights leads to error accumulation, we don't check the loss
if master_weights:
assert_close(torch_loss, loss, rtol=rtol, atol=atol)

zero_optim.step()
torch_optim.step()

check_param(model, torch_model, mixed_precision)
if master_weights:
check_param(model, torch_model, mixed_precision)


@parameterize("placement_config", PLACEMENT_CONFIGS)
Expand Down
Loading