From 4e9a045a7b107a3a5317bb197d2c21835d2bd617 Mon Sep 17 00:00:00 2001
From: Alex Kogan <82225080+sakogan@users.noreply.github.com>
Date: Tue, 3 Oct 2023 20:00:52 -0400
Subject: [PATCH] Fix a bug in DeepSpeedMLP (#4389)

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
---
 deepspeed/ops/transformer/inference/ds_mlp.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/deepspeed/ops/transformer/inference/ds_mlp.py b/deepspeed/ops/transformer/inference/ds_mlp.py
index b6638f98a0ea..36de06db920f 100644
--- a/deepspeed/ops/transformer/inference/ds_mlp.py
+++ b/deepspeed/ops/transformer/inference/ds_mlp.py
@@ -20,8 +20,8 @@ def __init__(self, config, mp_group=None, q_scales=None, q_groups=1, merge_count
 
         self.config = config
 
-        data_type = torch.half if self.config.dtype == torch.int8 else self.config.dtype
-        data_type_fp = data_type
+        data_type = torch.int8 if self.config.dtype == torch.int8 else self.config.dtype
+        data_type_fp = torch.half if self.config.dtype == torch.int8 else self.config.dtype
         device = get_accelerator().current_device_name()
 
         proj_factor = 2 if self.config.mlp_act_func_type in GATED_ACTIVATION_TYPES else 1
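
Note (not part of the patch): below is a minimal standalone sketch of the dtype selection this change implements, assuming config.dtype is a torch dtype such as torch.float16 or torch.int8; the helper name select_mlp_dtypes is hypothetical and used only for illustration. With an int8 (quantized) config, the weight/data dtype now stays int8 while the floating-point side uses half, instead of both falling back to half as before.

    import torch

    def select_mlp_dtypes(config_dtype: torch.dtype):
        # Quantized (int8) configs keep int8 for the weight/data tensors but
        # still need a floating-point dtype (half) for the fp-side buffers;
        # any other config uses its own dtype for both.
        data_type = torch.int8 if config_dtype == torch.int8 else config_dtype
        data_type_fp = torch.half if config_dtype == torch.int8 else config_dtype
        return data_type, data_type_fp

    # Before this patch an int8 config produced (half, half); it now yields:
    assert select_mlp_dtypes(torch.int8) == (torch.int8, torch.float16)
    assert select_mlp_dtypes(torch.float16) == (torch.float16, torch.float16)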