Skip to content
This repository was archived by the owner on Aug 7, 2024. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion benchmarks/bench_linear_float8.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,6 @@ def main(

linear_float8 = Float8Linear.from_float(
copy.deepcopy(linear_ref),
emulate=False,
config=config,
)
scaling_repr = linear_float8.scaling_repr()
Expand Down
1 change: 0 additions & 1 deletion benchmarks/bench_multi_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ def get_model(K, N, is_fp8, base_dtype=torch.float32):
if is_fp8:
swap_linear_with_float8_linear(
m,
emulate=False,
config=config,
)
return m
Expand Down
3 changes: 3 additions & 0 deletions float8_experimental/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ class Float8LinearConfig:
# This can cause a memory spike however so we keep this off by default.
pad_inner_dim: bool = False

# If True, emulation is used instead of hardware accelerated gemm
emulate: bool = False


# If True, use 'fnuz' float8 types for calculations.
# Currently, ROCm only supports fnuz variants.
Expand Down
5 changes: 1 addition & 4 deletions float8_experimental/float8_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,8 @@ def __init__(self, *args, **kwargs):
)
# Amax scales should always be kept as float32.
self.always_float32_buffers = set()
emulate = kwargs.pop("emulate", False)
config = kwargs.pop("config")
emulate = config.emulate
super().__init__(*args, **kwargs)

# Defines the scaling behavior of input, weight, grad_output
Expand Down Expand Up @@ -434,15 +434,13 @@ def extra_repr(self):
def from_float(
cls,
mod,
emulate: bool = False,
config: Optional[Float8LinearConfig] = None,
):
"""
Create an nn.Linear with fp8 compute from a regular nn.Linear

Args:
mod (torch.nn.Linear): nn.Linear to convert
emulate (bool): whether to emulate fp8 matmul logic in float32
config (Optional[Float8LinearConfig]): configuration for conversion to float8
"""
if config is None:
Expand All @@ -452,7 +450,6 @@ def from_float(
mod.in_features,
mod.out_features,
bias=False,
emulate=emulate,
config=config,
)
new_mod.weight = mod.weight
Expand Down
3 changes: 0 additions & 3 deletions float8_experimental/float8_linear_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,6 @@ def post_order_traversal(
def swap_linear_with_float8_linear(
module: nn.Module,
*,
emulate: bool = False,
module_filter_fn: Optional[Callable[[str, nn.Module], bool]] = None,
config: Float8LinearConfig = None,
) -> Optional[nn.Module]:
Expand All @@ -136,7 +135,6 @@ def swap_linear_with_float8_linear(

Args:
module: Module to modify.
emulate: If True, emulation is used instead of hardware accelerated gemm
module_filter_fn: If specified, only the `torch.nn.Linear` subclasses that
that pass the filter function will be swapped. The inputs to the
filter function are the FQN and module instance.
Expand All @@ -149,7 +147,6 @@ def swap_linear_with_float8_linear(
config = Float8LinearConfig()
from_float = lambda m: Float8Linear.from_float(
m,
emulate=emulate,
config=config,
)
return swap_linear_layers(
Expand Down
28 changes: 16 additions & 12 deletions test/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,12 +152,10 @@ def _test_linear_impl(
self,
x,
m_ref,
emulate: bool,
config: Float8LinearConfig,
):
m_fp8 = Float8Linear.from_float(
copy.deepcopy(m_ref),
emulate,
config,
)
for _ in range(2):
Expand Down Expand Up @@ -264,11 +262,11 @@ def test_linear(
cast_config_grad_output=Float8TensorCastConfig(
scaling_type=scaling_type_grad_output
),
emulate=emulate,
)
self._test_linear_impl(
x,
m_ref,
emulate,
config,
)

Expand Down Expand Up @@ -303,8 +301,9 @@ def test_autocast_outputs(
cast_config_grad_output=Float8TensorCastConfig(
scaling_type=TensorScalingType.DELAYED
),
emulate=emulate,
)
m = Float8Linear.from_float(copy.deepcopy(m_ref), emulate, config)
m = Float8Linear.from_float(copy.deepcopy(m_ref), config)

# autocast off
x = torch.randn(16, 32, device="cuda", dtype=linear_dtype)
Expand Down Expand Up @@ -339,8 +338,8 @@ def test_type_cast(self, linear_dtype: torch.dtype, emulate: bool):
)

m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
config = Float8LinearConfig()
m = Float8Linear.from_float(copy.deepcopy(m), emulate, config)
config = Float8LinearConfig(emulate=emulate)
m = Float8Linear.from_float(copy.deepcopy(m), config)

# Cast the module to dtype
m = m.to(dtype=linear_dtype)
Expand Down Expand Up @@ -389,10 +388,10 @@ def test_repr(self):
cast_config_weight=Float8TensorCastConfig(
scaling_type=TensorScalingType.DELAYED
),
emulate=True,
)
m = Float8Linear.from_float(
copy.deepcopy(m),
emulate=True,
config=config,
)
s = m.__repr__()
Expand Down Expand Up @@ -604,7 +603,8 @@ class TestFloat8LinearUtils(unittest.TestCase):
def test_swap_root_linear(self):
for emulate in [True, False]:
module = nn.Linear(3, 3)
module = swap_linear_with_float8_linear(module, emulate=emulate)
config = Float8LinearConfig(emulate=emulate)
module = swap_linear_with_float8_linear(module, config=config)
self.assertIsInstance(module, Float8Linear)
self.assertEqual(module.linear_mm_config.y.emulate, emulate)
self.assertEqual(module.linear_mm_config.y.emulate, emulate)
Expand All @@ -613,11 +613,12 @@ def test_swap_root_linear_with_children_raises(self):
for emulate in [True, False]:
module = nn.Linear(3, 3)
module.child = nn.Sequential(nn.Linear(3, 3))
config = Float8LinearConfig(emulate=emulate)
with self.assertRaisesRegex(
AssertionError,
"Does not support a root nn.Linear with children",
):
swap_linear_with_float8_linear(module, emulate=emulate)
swap_linear_with_float8_linear(module, config=config)

def test_swap_submodule_linears(self):
class MLP(nn.Module):
Expand All @@ -628,7 +629,8 @@ def __init__(self, dim: int):

for emulate in [True, False]:
model = nn.Sequential(MLP(3), nn.Linear(3, 3), MLP(3))
model = swap_linear_with_float8_linear(model, emulate=emulate)
config = Float8LinearConfig(emulate=emulate)
model = swap_linear_with_float8_linear(model, config=config)
self.assertIsInstance(model[0].lin1, Float8Linear)
self.assertIsInstance(model[0].lin2, Float8Linear)
self.assertIsInstance(model[1], Float8Linear)
Expand All @@ -655,9 +657,10 @@ def module_filter_fn(fqn, mod):
and mod.out_features % 16 == 0
)

config = Float8LinearConfig(emulate=True)
model = swap_linear_with_float8_linear(
model,
emulate=True,
config=config,
module_filter_fn=module_filter_fn,
)
# in_features=8, out_features=32, 8 is less than 32.
Expand All @@ -683,9 +686,10 @@ def __init__(self, dim: int):
"0.lin2",
"2.lin1",
]
config = Float8LinearConfig(emulate=True)
model = swap_linear_with_float8_linear(
model,
emulate=True,
config=config,
module_filter_fn=module_filter_fn,
)
self.assertTrue(type(model[0].lin1) is Float8Linear)
Expand Down
23 changes: 15 additions & 8 deletions test/test_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
def _test_compile_base(
backend: str,
fullgraph: bool,
emulate: bool,
config: Float8LinearConfig,
dtype: torch.dtype,
):
Expand All @@ -50,7 +49,6 @@ def _test_compile_base(

m_fp8 = Float8Linear.from_float(
copy.deepcopy(m_ref),
emulate,
config,
)

Expand Down Expand Up @@ -95,11 +93,11 @@ def test_eager_only(
cast_config_grad_output=Float8TensorCastConfig(
scaling_type=scaling_type_grad_output
),
emulate=emulate,
)
_test_compile_base(
"eager",
fullgraph,
emulate,
config,
dtype,
)
Expand Down Expand Up @@ -133,11 +131,11 @@ def test_aot_eager(
cast_config_grad_output=Float8TensorCastConfig(
scaling_type=scaling_type_grad_output
),
emulate=emulate,
)
_test_compile_base(
"aot_eager",
fullgraph,
emulate,
config,
dtype,
)
Expand Down Expand Up @@ -171,11 +169,11 @@ def test_inductor(
cast_config_grad_output=Float8TensorCastConfig(
scaling_type=scaling_type_grad_output
),
emulate=emulate,
)
_test_compile_base(
"inductor",
fullgraph,
emulate,
config,
dtype,
)
Expand Down Expand Up @@ -315,11 +313,20 @@ def test_sync_amax_func_cuda_graph_success():
my_module = nn.Sequential(
nn.Linear(16, 32, bias=True), nn.ReLU(), nn.Linear(32, 16, bias=True)
).to("cuda")
config = Float8LinearConfig(
cast_config_input=Float8TensorCastConfig(
scaling_type=TensorScalingType.DELAYED
),
cast_config_weight=Float8TensorCastConfig(
scaling_type=TensorScalingType.DELAYED
),
cast_config_grad_output=Float8TensorCastConfig(
scaling_type=TensorScalingType.DELAYED
),
)
swap_linear_with_float8_linear(
my_module,
scaling_type_input=TensorScalingType.DELAYED,
scaling_type_weight=TensorScalingType.DELAYED,
scaling_type_grad_output=TensorScalingType.DELAYED,
config=config,
)
inpt = torch.randn(
16, 16, device="cuda", dtype=torch.float32, requires_grad=True
Expand Down
11 changes: 6 additions & 5 deletions test/test_dtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from float8_experimental.config import TensorScalingType
from float8_experimental import Float8LinearConfig

from float8_experimental.float8_dynamic_utils import NoopFwToFloat8E5M2Bw
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
Expand Down Expand Up @@ -184,14 +184,15 @@ def _test_fp8_mlp_tensor_parallelism_base(
# For now, only supports dynamic scaling of `x` and `dL_dY`.
# TODO(future): add support for float8 all-gather with delayed scaling
# for activations and gradients.
config = Float8LinearConfig(emulate=True)

toy_model = ToyModel().to(device)
toy_model_fp8 = swap_linear_with_float8_linear(toy_model, emulate=True)
toy_model_fp8 = swap_linear_with_float8_linear(toy_model, config=config)

tp_model = copy.deepcopy(toy_model)
tp_model = swap_linear_with_float8_linear(tp_model, emulate=True)
tp_model = swap_linear_with_float8_linear(tp_model, config=config)
sp_model = copy.deepcopy(toy_model)
sp_model = swap_linear_with_float8_linear(sp_model, emulate=True)
sp_model = swap_linear_with_float8_linear(sp_model, config=config)

# vanilla TP
tp_model = parallelize_module(
Expand Down Expand Up @@ -222,7 +223,7 @@ def _test_fp8_mlp_tensor_parallelism_base(

# PrepareFloat8ModuleInput with specific submodule fqn
sp_model2 = copy.deepcopy(toy_model)
sp_model2 = swap_linear_with_float8_linear(sp_model2, emulate=True)
sp_model2 = swap_linear_with_float8_linear(sp_model2, config=config)

sp_model2 = parallelize_module(
sp_model2,
Expand Down
3 changes: 2 additions & 1 deletion test/test_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,14 @@ def fsdp_main(rank, world_size, args):
)
config = Float8LinearConfig(
cast_config_weight=Float8TensorCastConfig(scaling_type=scaling_type_weight),
# TODO(future): delete this arg as it's always False
emulate=False,
)

# Note: we only iterate over `scaling_type_weight` because FSDP only interacts
# with weights.
swap_linear_with_float8_linear(
model_fp8,
emulate=False,
config=config,
)

Expand Down
11 changes: 6 additions & 5 deletions test/test_fsdp2/test_fsdp2.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,9 @@ def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool):
# requirement to use a smaller activation size
float8_linear_config = Float8LinearConfig(
enable_fsdp_fp8_all_gather=enable_fsdp_fp8_all_gather,
emulate=True,
)
swap_linear_with_float8_linear(model, emulate=True, config=float8_linear_config)
swap_linear_with_float8_linear(model, config=float8_linear_config)
model_unsharded_numel = sum(p.numel() for p in model.parameters())
model_sharded_numel = (model_unsharded_numel + 1) // 2
block_lin_weight_numel = 0
Expand Down Expand Up @@ -294,10 +295,10 @@ def test_weight_subclass_dynamic(self):
module_fp32 = self.init_single_module()
float8_linear_config = Float8LinearConfig(
enable_fsdp_fp8_all_gather=True,
emulate=True,
)
module = swap_linear_with_float8_linear(
module_fp32,
emulate=True,
config=float8_linear_config,
)
self.assertIsInstance(module.weight, tensor_cls)
Expand All @@ -311,7 +312,6 @@ def test_weight_subclass_dynamic(self):
module = self.init_multi_module()
module = swap_linear_with_float8_linear(
module,
emulate=True,
config=float8_linear_config,
)
for param_name, param in module.named_parameters():
Expand Down Expand Up @@ -517,12 +517,13 @@ def test_bf16_mp_fp8_dynamic_multi_parity(self):
# parameters, changing the numerics.
module = self.init_multi_module()
ref_module_bf16 = copy.deepcopy(module).to(torch.bfloat16)
float8_config = Float8LinearConfig(emulate=True)
ref_module_bf16 = swap_linear_with_float8_linear(
ref_module_bf16,
emulate=True,
config=float8_config,
)
ref_module_fp32 = copy.deepcopy(module).cuda()
module = swap_linear_with_float8_linear(module, emulate=True)
module = swap_linear_with_float8_linear(module, config=float8_config)
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16)
for mlp in module:
fully_shard(mlp, mp_policy=mp_policy)
Expand Down
2 changes: 1 addition & 1 deletion test/test_fsdp_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def get_model(K, N, is_fp8, emulate, base_dtype=torch.float32):
cast_config_grad_output=Float8TensorCastConfig(
scaling_type=TensorScalingType.DELAYED
),
emulate=emulate,
)

m = nn.Sequential(
Expand All @@ -74,7 +75,6 @@ def get_model(K, N, is_fp8, emulate, base_dtype=torch.float32):
)
swap_linear_with_float8_linear(
m,
emulate=emulate,
config=config,
)
return m
Expand Down
Loading