This repository was archived by the owner on Aug 7, 2024. It is now read-only.
6 changes: 3 additions & 3 deletions README.md
@@ -44,7 +44,7 @@ from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_f
 m = Model(...)

 # optional: filter modules from being eligible for float8 conversion
-def module_filter_fn(fqn: str, mod: torch.nn.Module):
+def module_filter_fn(mod: torch.nn.Module, fqn: str):
     # don't convert the output module
     if fqn == "output":
         return False
@@ -91,9 +91,9 @@ from float8_experimental.float8_linear import TensorScalingType
 # create model
 m = Model(...)

-# optional: configure for compatibility with FSDP. Note that workarounds
+# optional: configure for compatibility with FSDP. Note that workarounds
 # gated with config.enable_amax_init and
-# config.enable_pre_and_post_forward are needed for
+# config.enable_pre_and_post_forward are needed for
 # autocast + compile + FSDP + float8 to work
 from float8_experimental import Float8LinearConfig, TensorScalingType, Float8TensorCastConfig
 config = Float8LinearConfig(
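Taken together, the README hunks flip the filter callback to take the module first and its fully qualified name (FQN) second. A minimal sketch of the new call-site shape, assuming a toy two-layer `Model` invented here for illustration (`swap_linear_with_float8_linear` is the converter whose signature changes later in this diff):

import torch
import torch.nn as nn
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

# toy stand-in for the README's Model(...); not part of this PR
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(16, 16)
        self.output = nn.Linear(16, 8)

m = Model()

# new order after this PR: module instance first, then its FQN
def module_filter_fn(mod: torch.nn.Module, fqn: str):
    # don't convert the output module
    return fqn != "output"

m = swap_linear_with_float8_linear(m, module_filter_fn=module_filter_fn)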
14 changes: 6 additions & 8 deletions float8_experimental/float8_linear_utils.py
@@ -60,7 +60,7 @@ def swap_linear_layers(
     module: nn.Module,
     from_float_func: Callable[[nn.Linear], nn.Linear],
     *,
-    module_filter_fn: Optional[Callable[[str, nn.Module], bool]] = None,
+    module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
 ) -> Optional[nn.Module]:
     """
     Generic function to swap linear layers in a module with a new type of linear layer.
@@ -74,13 +74,13 @@ def swap_linear_layers(
         from_float_func: Function that accepts a linear layer and returns a new type of linear layer.
         module_filter_fn: If specified, only the `torch.nn.Linear` subclasses
             that pass the filter function will be swapped. The inputs to the
-            filter function are the FQN and module instance.
+            filter function are the module instance and the FQN.

     Returns:
         nn.Module: The modified module with swapped linear layers.
     """
     if isinstance(module, nn.Linear) and (
-        module_filter_fn is None or module_filter_fn("", module)
+        module_filter_fn is None or module_filter_fn(module, "")
     ):
         if len(list(module.children())) > 0:
             raise AssertionError(
@@ -109,9 +109,7 @@ def post_order_traversal(
             post_order_traversal(child_module, new_fqn, module)

         if isinstance(module, nn.Linear) and (
-            # linear_layer_filter is None or linear_layer_filter(module)
-            module_filter_fn is None
-            or module_filter_fn(cur_fqn, module)
+            module_filter_fn is None or module_filter_fn(module, cur_fqn)
         ):
             assert (
                 parent_module is not None
@@ -127,7 +125,7 @@
 def swap_linear_with_float8_linear(
     module: nn.Module,
     *,
-    module_filter_fn: Optional[Callable[[str, nn.Module], bool]] = None,
+    module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
     config: Float8LinearConfig = None,
 ) -> Optional[nn.Module]:
     """
@@ -137,7 +135,7 @@
         module: Module to modify.
         module_filter_fn: If specified, only the `torch.nn.Linear` subclasses
             that pass the filter function will be swapped. The inputs to the
-            filter function are the FQN and module instance.
+            filter function are the module instance and the FQN.
         config (Float8LinearConfig): configuration for conversion to float8

     Returns:
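To make the traversal contract concrete, here is a standalone sketch of the post-order swap with the new `(module, fqn)` order. It mirrors the logic in the hunks above but is simplified (no multi-child assertion or error messages) and is not the repo's exact code:

from typing import Callable, Optional

import torch.nn as nn

def swap_linear_layers_sketch(
    module: nn.Module,
    from_float_func: Callable[[nn.Linear], nn.Linear],
    *,
    module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
) -> nn.Module:
    # the root module itself may be a Linear: it has no parent, so its FQN is ""
    if isinstance(module, nn.Linear) and (
        module_filter_fn is None or module_filter_fn(module, "")
    ):
        return from_float_func(module)

    def post_order_traversal(mod: nn.Module, cur_fqn: str, parent: Optional[nn.Module]):
        # recurse into children first, extending the FQN with dot-separated names
        for child_name, child in mod.named_children():
            new_fqn = child_name if cur_fqn == "" else f"{cur_fqn}.{child_name}"
            post_order_traversal(child, new_fqn, mod)
        # then swap this node if it is a Linear that passes the filter
        if isinstance(mod, nn.Linear) and (
            module_filter_fn is None or module_filter_fn(mod, cur_fqn)
        ):
            assert parent is not None, "a Linear root is handled above"
            setattr(parent, cur_fqn.split(".")[-1], from_float_func(mod))

    post_order_traversal(module, "", None)
    return module

Running this with from_float_func=lambda lin: lin is a no-op swap, which makes it easy to exercise the filter behavior in isolation.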
4 changes: 2 additions & 2 deletions float8_experimental/inference.py
@@ -213,7 +213,7 @@ def quantize_to_float8(
     module: nn.Module,
     quant_config: QuantConfig,
     *,
-    module_filter_fn: Optional[Callable[[str, nn.Module], bool]] = None,
+    module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
     use_fast_accum: bool = True,
 ) -> Optional[nn.Module]:
     """
@@ -228,7 +228,7 @@ def quantize_to_float8(
         quant_config (QuantConfig): Quantization configuration for Float8 conversion.
         module_filter_fn: If specified, only the `torch.nn.Linear` subclasses
             that pass the filter function will be swapped. The inputs to the
-            filter function are the FQN and module instance.
+            filter function are the module instance and the FQN.
         use_fast_accum: Whether to enable fast accumulation for the Float8InferenceLinear. Defaults to True.

     Returns:
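The inference entry point follows the same flipped order. A hedged usage sketch; the `ActivationCasting.DYNAMIC` enum and the `QuantConfig` constructor shape are assumptions about this file, so verify against inference.py before copying:

import torch.nn as nn
from float8_experimental.inference import (
    ActivationCasting,  # assumed enum in inference.py
    QuantConfig,        # constructor shape assumed
    quantize_to_float8,
)

model = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 8))
quant_config = QuantConfig(ActivationCasting.DYNAMIC)

# (module, fqn) order: keep the final projection ("2") unquantized
model = quantize_to_float8(
    model,
    quant_config,
    module_filter_fn=lambda mod, fqn: fqn != "2",
)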
4 changes: 2 additions & 2 deletions test/test_base.py
@@ -649,7 +649,7 @@ def __init__(self, dim: int):

         size_limit = 32

-        def module_filter_fn(fqn, mod):
+        def module_filter_fn(mod, fqn):
             return (
                 mod.in_features >= size_limit
                 and mod.out_features >= size_limit
@@ -682,7 +682,7 @@ def __init__(self, dim: int):
             self.lin2 = nn.Linear(4 * dim, dim)

         model = nn.Sequential(MLP(3), nn.Linear(3, 3), MLP(3))
-        module_filter_fn = lambda fqn, mod: fqn not in [
+        module_filter_fn = lambda mod, fqn: fqn not in [
             "0.lin2",
             "2.lin1",
         ]
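For readers unsure which strings the filter receives, this self-contained snippet reproduces the test's toy model and prints the FQN seen for each nn.Linear (pure PyTorch, no float8_experimental required):

import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.lin1 = nn.Linear(dim, 4 * dim)
        self.lin2 = nn.Linear(4 * dim, dim)

model = nn.Sequential(MLP(3), nn.Linear(3, 3), MLP(3))
module_filter_fn = lambda mod, fqn: fqn not in ["0.lin2", "2.lin1"]

# named_modules() yields the same dot-separated FQNs the swap traversal builds
for fqn, mod in model.named_modules():
    if isinstance(mod, nn.Linear):
        print(fqn, "swap" if module_filter_fn(mod, fqn) else "skip")
# prints: 0.lin1 swap / 0.lin2 skip / 1 swap / 2.lin1 skip / 2.lin2 swap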